raazkumar commited on
Commit
96db982
·
verified ·
1 Parent(s): 03126cc

Upload production/production_server.py

Browse files
Files changed (1) hide show
  1. production/production_server.py +843 -0
production/production_server.py ADDED
@@ -0,0 +1,843 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Production-grade API server for ml-intern.
3
+
4
+ Features:
5
+ - FastAPI with async endpoints
6
+ - Distributed rate limiting (Redis-backed token bucket)
7
+ - Circuit breaker for external API resilience
8
+ - Request/response caching with Redis TTL
9
+ - Multi-tenant session isolation
10
+ - Health checks and graceful shutdown
11
+ - Structured logging with correlation IDs
12
+ - Cost tracking and budget enforcement
13
+ - Connection pooling for all HTTP clients
14
+ """
15
+
16
+ import asyncio
17
+ import hashlib
18
+ import json
19
+ import logging
20
+ import os
21
+ import signal
22
+ import sys
23
+ import time
24
+ import uuid
25
+ from contextlib import asynccontextmanager
26
+ from dataclasses import dataclass, field
27
+ from typing import Any, Optional
28
+
29
+ import aioredis
30
+ import asyncpg
31
+ from fastapi import FastAPI, HTTPException, Request, Depends, BackgroundTasks
32
+ from fastapi.middleware.cors import CORSMiddleware
33
+ from fastapi.middleware.gzip import GZipMiddleware
34
+ from fastapi.responses import JSONResponse, StreamingResponse
35
+ from pydantic import BaseModel, Field
36
+ import uvicorn
37
+ from prometheus_client import Counter, Histogram, Gauge, generate_latest, CONTENT_TYPE_LATEST
38
+
39
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Every knob is environment-driven so deployments can tune behavior without
# a code change; the literals below are the development defaults.

REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379")
DATABASE_URL = os.environ.get("DATABASE_URL", "postgresql://localhost/ml_intern")
# Per-process cap on in-flight requests (enforced by ConcurrencyLimiter).
MAX_CONCURRENT_REQUESTS = int(os.environ.get("MAX_CONCURRENT_REQUESTS", "100"))
# Fallback requests-per-minute limit when no provider-specific limit applies.
DEFAULT_RPM_LIMIT = int(os.environ.get("DEFAULT_RPM_LIMIT", "40"))
# Seconds. NOTE(review): not referenced anywhere in this file -- confirm it
# is consumed by the real LLM client when one is wired in.
REQUEST_TIMEOUT = float(os.environ.get("REQUEST_TIMEOUT", "120"))
# TTL for cached LLM responses in Redis (see RedisManager.set_cache).
CACHE_TTL_SECONDS = int(os.environ.get("CACHE_TTL_SECONDS", "300"))
# Spending ceiling enforced per session by CostTracker.
BUDGET_USD_PER_SESSION = float(os.environ.get("BUDGET_USD_PER_SESSION", "10.0"))
# Consecutive failures that trip a provider's circuit breaker open.
CIRCUIT_BREAKER_FAILURE_THRESHOLD = int(os.environ.get("CIRCUIT_BREAKER_FAILURE_THRESHOLD", "5"))
# Seconds an open breaker waits before allowing a half-open probe call.
CIRCUIT_BREAKER_RECOVERY_TIMEOUT = int(os.environ.get("CIRCUIT_BREAKER_RECOVERY_TIMEOUT", "60"))
52
+
53
# ---------------------------------------------------------------------------
# Prometheus Metrics
# ---------------------------------------------------------------------------
# Registered once at import time on the default registry; /metrics serves
# them via generate_latest().

# Incremented by the correlation-ID middleware for every HTTP request.
REQUEST_COUNT = Counter(
    "ml_intern_requests_total",
    "Total requests",
    ["method", "endpoint", "status", "provider"],
)
# End-to-end request latency, observed by the middleware.
REQUEST_LATENCY = Histogram(
    "ml_intern_request_duration_seconds",
    "Request duration",
    ["method", "endpoint", "provider"],
    buckets=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0],
)
# NOTE(review): declared but never set/incremented in this file -- confirm
# a writer exists before dashboards depend on it.
ACTIVE_SESSIONS = Gauge(
    "ml_intern_active_sessions",
    "Number of active sessions",
)
# Incremented by CostTracker.record_spend.
LLM_COST_USD = Counter(
    "ml_intern_llm_cost_usd_total",
    "Total LLM cost in USD",
    ["provider", "model"],
)
# Hit/miss accounting lives in RedisManager.get_cache.
CACHE_HIT_COUNT = Counter(
    "ml_intern_cache_hits_total",
    "Cache hits",
    ["cache_type"],
)
CACHE_MISS_COUNT = Counter(
    "ml_intern_cache_misses_total",
    "Cache misses",
    ["cache_type"],
)
# Updated by CircuitBreaker on every state check/transition.
CIRCUIT_BREAKER_STATE = Gauge(
    "ml_intern_circuit_breaker_state",
    "Circuit breaker state (0=closed, 1=half-open, 2=open)",
    ["provider"],
)
92
+
93
+ # ---------------------------------------------------------------------------
94
+ # Structured Logging
95
+ # ---------------------------------------------------------------------------
96
+
97
class CorrelationIdFilter(logging.Filter):
    """Guarantee every log record carries a ``correlation_id`` attribute.

    Records that were not tagged by the request middleware get the
    placeholder ``"none"`` so the log format string never fails.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        # Only fill in the placeholder when the middleware hasn't tagged
        # the record already; never suppress a record.
        if not hasattr(record, "correlation_id"):
            record.correlation_id = "none"
        return True
101
+
102
# Root logging configuration. The filter must sit on the HANDLER, not only
# on our named logger: filters attached to a logger apply solely to records
# logged through that logger, while the handler formats records from every
# logger that propagates to root. Without the handler-level filter, any
# third-party record lacking `correlation_id` makes the format string below
# raise at emit time.
_log_handler = logging.StreamHandler(sys.stdout)
_log_handler.addFilter(CorrelationIdFilter())

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | correlation_id=%(correlation_id)s | %(name)s | %(message)s",
    handlers=[_log_handler],
)
logger = logging.getLogger("ml_intern.production")
# Kept for backward compatibility with code that logs through `logger`
# directly; harmless duplication of the handler-level filter.
logger.addFilter(CorrelationIdFilter())
109
+
110
+ # ---------------------------------------------------------------------------
111
+ # Database Layer
112
+ # ---------------------------------------------------------------------------
113
+
114
class DatabasePool:
    """Thin async wrapper around an asyncpg connection pool.

    ``connect()`` must be awaited before any query method is used; the
    query helpers mirror asyncpg's Connection API (execute / fetch /
    fetchrow / fetchval), each borrowing a pooled connection per call.
    """

    def __init__(self, dsn: str):
        self.dsn = dsn
        self._pool: Optional[asyncpg.Pool] = None  # created in connect()

    async def connect(self):
        """Open the pool (5-20 connections, 60 s per-command timeout)."""
        self._pool = await asyncpg.create_pool(
            self.dsn,
            min_size=5,
            max_size=20,
            command_timeout=60,
        )
        logger.info("Database pool connected")

    async def disconnect(self):
        """Close the pool if it was ever opened."""
        if self._pool:
            await self._pool.close()
        logger.info("Database pool disconnected")

    async def _call(self, method_name: str, query: str, *args):
        # Borrow a connection and delegate to the named asyncpg method.
        async with self._pool.acquire() as conn:
            method = getattr(conn, method_name)
            return await method(query, *args)

    async def execute(self, query: str, *args):
        """Run a statement; returns asyncpg's status string."""
        return await self._call("execute", query, *args)

    async def fetch(self, query: str, *args):
        """Return all result rows."""
        return await self._call("fetch", query, *args)

    async def fetchrow(self, query: str, *args):
        """Return the first row, or None."""
        return await self._call("fetchrow", query, *args)

    async def fetchval(self, query: str, *args):
        """Return the first column of the first row, or None."""
        return await self._call("fetchval", query, *args)
150
+
151
+ # ---------------------------------------------------------------------------
152
+ # Redis Layer (Caching + Rate Limiting + Distributed State)
153
+ # ---------------------------------------------------------------------------
154
+
155
class RedisManager:
    """Redis client for caching, rate limiting, and distributed state.

    A single instance is created in lifespan() and shared module-wide.
    All methods assume connect() has been awaited. ``_redis`` is also
    read directly by the /health probe.
    """

    def __init__(self, url: str):
        self.url = url
        # Set by connect(); decode_responses=True means values are str.
        self._redis: Optional[aioredis.Redis] = None

    async def connect(self):
        """Open the client and verify connectivity with a PING."""
        self._redis = aioredis.from_url(self.url, decode_responses=True)
        await self._redis.ping()
        logger.info("Redis connected")

    async def disconnect(self):
        """Close the client if it was ever opened."""
        if self._redis:
            await self._redis.close()
        logger.info("Redis disconnected")

    # --- Caching ---

    async def get_cache(self, key: str) -> Optional[str]:
        """Return the cached value for `key` (or None), updating hit/miss metrics.

        NOTE(review): the truthiness check counts a cached empty string
        as a miss (though it is still returned) -- confirm intended.
        """
        val = await self._redis.get(key)
        if val:
            CACHE_HIT_COUNT.labels(cache_type="llm_response").inc()
        else:
            CACHE_MISS_COUNT.labels(cache_type="llm_response").inc()
        return val

    async def set_cache(self, key: str, value: str, ttl: int = CACHE_TTL_SECONDS):
        """Store `value` under `key` with a TTL (default CACHE_TTL_SECONDS)."""
        await self._redis.setex(key, ttl, value)

    async def delete_cache(self, key: str):
        """Remove `key`; deleting a missing key is not an error."""
        await self._redis.delete(key)

    # --- Rate Limiting (Token Bucket) ---

    async def check_rate_limit(self, key: str, rpm: int) -> tuple[bool, float]:
        """Check if request is within rate limit. Returns (allowed, retry_after).

        Token bucket with capacity 1 (no burst): one token accrues every
        60/rpm seconds. State lives in a Redis hash so the limit is
        shared by all server processes.
        """
        now = time.time()
        bucket_key = f"ratelimit:{key}"

        # Lua script for atomic token bucket: the read-modify-write runs
        # server-side in one round trip, so concurrent workers cannot race.
        # NOTE(review): HMSET is deprecated in Redis (HSET takes multiple
        # pairs), and Redis truncates Lua numbers to integers in script
        # replies, so retry_after loses its fractional part -- confirm
        # whether sub-second precision matters here.
        script = """
        local key = KEYS[1]
        local now = tonumber(ARGV[1])
        local rpm = tonumber(ARGV[2])
        local interval = 60.0 / rpm

        local last = redis.call('hget', key, 'last')
        local tokens = redis.call('hget', key, 'tokens')

        if not last then
            last = 0
            tokens = 1
        else
            last = tonumber(last)
            tokens = tonumber(tokens)
        end

        local elapsed = now - last
        tokens = math.min(1, tokens + elapsed / interval)

        if tokens >= 1 then
            tokens = tokens - 1
            redis.call('hmset', key, 'last', now, 'tokens', tokens)
            redis.call('expire', key, 120)
            return {1, 0}
        else
            local retry_after = interval - (elapsed % interval)
            redis.call('hmset', key, 'last', last, 'tokens', tokens)
            redis.call('expire', key, 120)
            return {0, retry_after}
        end
        """

        result = await self._redis.eval(script, 1, bucket_key, now, rpm)
        allowed = bool(result[0])
        retry_after = float(result[1]) if not allowed else 0.0
        return allowed, retry_after

    # --- Circuit Breaker State ---

    async def get_circuit_state(self, provider: str) -> dict:
        """Return the provider's breaker state; defaults to closed/no failures."""
        key = f"circuit:{provider}"
        val = await self._redis.get(key)
        if val:
            return json.loads(val)
        return {"state": "closed", "failures": 0, "last_failure": 0}

    async def set_circuit_state(self, provider: str, state: dict):
        """Persist breaker state with a 1 h TTL so stale entries self-expire."""
        key = f"circuit:{provider}"
        await self._redis.setex(key, 3600, json.dumps(state))
246
+
247
+ # ---------------------------------------------------------------------------
248
+ # Circuit Breaker
249
+ # ---------------------------------------------------------------------------
250
+
251
class CircuitBreaker:
    """Redis-backed circuit breaker shared across server processes.

    Transitions: closed -> open after `failure_threshold` failures,
    open -> half-open once `recovery_timeout` seconds have elapsed since
    the last failure, half-open -> closed on the next recorded success.
    State is stored via RedisManager so every worker sees the same breaker.
    """

    def __init__(self, redis: RedisManager, provider: str):
        self.redis = redis
        self.provider = provider
        self.failure_threshold = CIRCUIT_BREAKER_FAILURE_THRESHOLD
        self.recovery_timeout = CIRCUIT_BREAKER_RECOVERY_TIMEOUT

    async def can_execute(self) -> bool:
        """Return whether a call to the provider may proceed right now."""
        state = await self.redis.get_circuit_state(self.provider)
        gauge = CIRCUIT_BREAKER_STATE.labels(provider=self.provider)

        if state["state"] != "open":
            # Closed or half-open: traffic flows; just publish the gauge.
            gauge.set(0 if state["state"] == "closed" else 1)
            return True

        if time.time() - state["last_failure"] <= self.recovery_timeout:
            # Still cooling down -- keep rejecting.
            gauge.set(2)
            return False

        # Cooldown elapsed: permit a probe call in half-open mode.
        state["state"] = "half-open"
        state["failures"] = 0
        await self.redis.set_circuit_state(self.provider, state)
        gauge.set(1)
        logger.info(f"Circuit breaker for {self.provider} entering half-open state")
        return True

    async def record_success(self):
        """Close the breaker again after a successful half-open probe."""
        state = await self.redis.get_circuit_state(self.provider)
        if state["state"] != "half-open":
            return
        state["state"] = "closed"
        state["failures"] = 0
        await self.redis.set_circuit_state(self.provider, state)
        CIRCUIT_BREAKER_STATE.labels(provider=self.provider).set(0)
        logger.info(f"Circuit breaker for {self.provider} closed after recovery")

    async def record_failure(self):
        """Count one failure; trip the breaker open at the threshold."""
        state = await self.redis.get_circuit_state(self.provider)
        state["failures"] += 1
        state["last_failure"] = time.time()

        if state["failures"] >= self.failure_threshold:
            state["state"] = "open"
            CIRCUIT_BREAKER_STATE.labels(provider=self.provider).set(2)
            logger.warning(
                f"Circuit breaker for {self.provider} OPENED after "
                f"{state['failures']} failures"
            )

        await self.redis.set_circuit_state(self.provider, state)
302
+
303
+ # ---------------------------------------------------------------------------
304
+ # Cost Tracking
305
+ # ---------------------------------------------------------------------------
306
+
307
@dataclass
class CostTracker:
    """Per-session cost tracking with budget enforcement.

    Accumulates spend in-memory for one session and mirrors it into the
    LLM_COST_USD Prometheus counter; `can_spend` is the budget gate.
    """

    session_id: str
    budget_usd: float = BUDGET_USD_PER_SESSION
    spent_usd: float = 0.0
    provider: str = "unknown"
    model: str = "unknown"

    def can_spend(self, estimated_cost: float) -> bool:
        """True when spending `estimated_cost` more stays within budget."""
        projected = self.spent_usd + estimated_cost
        return projected <= self.budget_usd

    def record_spend(self, cost_usd: float):
        """Add `cost_usd` to the session total and the Prometheus counter."""
        self.spent_usd += cost_usd
        counter = LLM_COST_USD.labels(provider=self.provider, model=self.model)
        counter.inc(cost_usd)
        logger.info(
            f"Session {self.session_id}: spent ${cost_usd:.4f}, "
            f"total ${self.spent_usd:.4f} / ${self.budget_usd:.2f}"
        )
327
+
328
+ # ---------------------------------------------------------------------------
329
+ # Semaphore for Concurrency Control
330
+ # ---------------------------------------------------------------------------
331
+
332
class ConcurrencyLimiter:
    """Global concurrent request limiter (per-process asyncio semaphore).

    Backward compatible with explicit acquire()/release(); additionally
    usable as an async context manager so a slot can never leak when the
    guarded code raises::

        async with limiter:
            ...
    """

    def __init__(self, max_concurrent: int):
        self.semaphore = asyncio.Semaphore(max_concurrent)

    async def acquire(self):
        """Wait until a slot is free, then take it."""
        await self.semaphore.acquire()

    def release(self):
        """Give a slot back."""
        self.semaphore.release()

    async def __aenter__(self) -> "ConcurrencyLimiter":
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        # Always release, even when the body raised.
        self.release()
343
+
344
+ # ---------------------------------------------------------------------------
345
+ # Pydantic Models
346
+ # ---------------------------------------------------------------------------
347
+
348
class ChatRequest(BaseModel):
    """OpenAI-compatible chat completion request with ml-intern extensions."""

    model: str = Field(..., description="Model ID (e.g., nim/llama-3-8b)")
    messages: list[dict] = Field(..., description="OpenAI-compatible messages")
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 4096
    stream: bool = False  # streaming requests bypass the Redis response cache
    tools: Optional[list[dict]] = None
    tool_choice: Optional[str] = "auto"
    session_id: Optional[str] = None  # generated server-side when omitted
    api_key: Optional[str] = None  # Provider-specific API key override
358
+
359
class ChatResponse(BaseModel):
    """Response body for /v1/chat/completions (non-streaming path)."""

    id: str  # server-generated UUID for this completion
    session_id: str
    model: str  # echoed from the request
    content: Optional[str] = None
    tool_calls: Optional[list[dict]] = None
    usage: dict = Field(default_factory=dict)  # prompt/completion/total token counts
    cost_usd: float = 0.0  # 0.0 on cache hits
    cached: bool = False  # True when served from the Redis cache
    finish_reason: Optional[str] = None
369
+
370
class HealthResponse(BaseModel):
    """Payload for GET /health."""

    status: str  # "healthy" when Redis and Postgres both respond, else "degraded"
    version: str = "1.0.0"
    uptime_seconds: float
    active_sessions: int  # currently always 0 (see health_check)
    redis_connected: bool
    db_connected: bool
    # provider -> "closed" / "half-open" / "open" / "unknown"
    circuit_breakers: dict[str, str]
378
+
379
class MetricsResponse(BaseModel):
    """Model for Prometheus output.

    NOTE(review): not referenced by any endpoint in this file -- /metrics
    returns a plain-text Response directly. Confirm whether this is dead.
    """

    prometheus: str
381
+
382
+ # ---------------------------------------------------------------------------
383
+ # Global State (set during lifespan)
384
+ # ---------------------------------------------------------------------------
385
+
386
# Populated by lifespan() at startup; None until then.
db_pool: Optional[DatabasePool] = None
redis_manager: Optional[RedisManager] = None
concurrency_limiter: Optional[ConcurrencyLimiter] = None
start_time: float = 0.0  # epoch seconds; set at startup for uptime reporting
# Created at import time, set by lifespan teardown / the signal handler.
# NOTE(review): nothing in this file awaits shutdown_event -- confirm
# long-running tasks elsewhere actually observe it.
shutdown_event: asyncio.Event = asyncio.Event()
391
+
392
+ # ---------------------------------------------------------------------------
393
+ # FastAPI App
394
+ # ---------------------------------------------------------------------------
395
+
396
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    Startup: connect Postgres and Redis, build the concurrency limiter,
    install signal handlers, create the schema. Shutdown: set the
    shutdown event and close connections.
    """
    global db_pool, redis_manager, concurrency_limiter, start_time

    start_time = time.time()

    # Initialize connections
    db_pool = DatabasePool(DATABASE_URL)
    await db_pool.connect()

    redis_manager = RedisManager(REDIS_URL)
    await redis_manager.connect()

    concurrency_limiter = ConcurrencyLimiter(MAX_CONCURRENT_REQUESTS)

    # Graceful shutdown handler. get_running_loop() is the supported call
    # inside a coroutine (get_event_loop() is deprecated there).
    # NOTE(review): uvicorn installs its own SIGTERM/SIGINT handlers which
    # these replace -- confirm the desired interaction in deployment.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        try:
            loop.add_signal_handler(sig, lambda: asyncio.create_task(_shutdown()))
        except NotImplementedError:
            # add_signal_handler is unavailable on Windows event loops.
            pass

    # Initialize database schema
    await _init_schema()

    logger.info("ml-intern production server started")

    yield

    # Shutdown
    logger.info("Shutting down...")
    shutdown_event.set()

    if redis_manager:
        await redis_manager.disconnect()
    if db_pool:
        await db_pool.disconnect()

    logger.info("ml-intern production server stopped")
434
+
435
async def _shutdown():
    """Signal-handler callback: log once and flip the global shutdown event.

    Does not close any resources itself; teardown happens in lifespan().
    """
    logger.info("Shutdown signal received")
    shutdown_event.set()
438
+
439
async def _init_schema():
    """Initialize database schema if not exists.

    Idempotent (CREATE TABLE IF NOT EXISTS), so it is safe to run on
    every startup, including multiple workers racing at boot.
    """
    # Per-tenant session bookkeeping, including budget accounting.
    await db_pool.execute("""
        CREATE TABLE IF NOT EXISTS sessions (
            id TEXT PRIMARY KEY,
            created_at TIMESTAMP DEFAULT NOW(),
            last_active_at TIMESTAMP DEFAULT NOW(),
            budget_usd NUMERIC DEFAULT 10.0,
            spent_usd NUMERIC DEFAULT 0.0,
            metadata JSONB DEFAULT '{}'
        )
    """)
    # One row per completion, written by _persist_request().
    # NOTE(review): no index on session_id -- per-session history queries
    # will sequential-scan; consider CREATE INDEX IF NOT EXISTS.
    await db_pool.execute("""
        CREATE TABLE IF NOT EXISTS requests (
            id TEXT PRIMARY KEY,
            session_id TEXT REFERENCES sessions(id),
            model TEXT,
            provider TEXT,
            input_tokens INTEGER,
            output_tokens INTEGER,
            cost_usd NUMERIC,
            latency_ms INTEGER,
            cached BOOLEAN DEFAULT FALSE,
            created_at TIMESTAMP DEFAULT NOW()
        )
    """)
    # Audit trail for breaker transitions.
    # NOTE(review): no code in this file inserts into circuit_events --
    # confirm a writer exists elsewhere.
    await db_pool.execute("""
        CREATE TABLE IF NOT EXISTS circuit_events (
            id SERIAL PRIMARY KEY,
            provider TEXT,
            event_type TEXT,
            details JSONB,
            created_at TIMESTAMP DEFAULT NOW()
        )
    """)
    logger.info("Database schema initialized")
475
+
476
app = FastAPI(
    title="ml-intern Production API",
    description="Production-grade API for ml-intern with rate limiting, caching, and multi-tenancy",
    version="1.0.0",
    lifespan=lifespan,  # startup/shutdown wiring defined above
)

# Compress response bodies larger than 1 KB when clients accept gzip.
app.add_middleware(GZipMiddleware, minimum_size=1000)
# SECURITY NOTE(review): allow_origins=["*"] together with
# allow_credentials=True effectively disables CORS protection for
# credentialed requests; restrict origins before public exposure.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
491
+
492
+ # ---------------------------------------------------------------------------
493
+ # Middleware
494
+ # ---------------------------------------------------------------------------
495
+
496
@app.middleware("http")
async def correlation_id_middleware(request: Request, call_next):
    """Attach a correlation ID to the request, log records, metrics, and response.

    The ID comes from the X-Correlation-ID header when the client sends
    one, otherwise a fresh UUID is generated; it is echoed back in the
    response header.
    """
    correlation_id = request.headers.get("X-Correlation-ID", str(uuid.uuid4()))
    request.state.correlation_id = correlation_id

    # Tag every log record emitted while handling this request. The factory
    # is restored in `finally`: previously each request's closure stayed
    # installed forever, chaining one wrapper per request (an unbounded
    # growth of the factory call chain) and bleeding IDs across requests.
    # NOTE(review): a contextvars.ContextVar would be fully safe under
    # concurrency; with overlapping requests this process-global factory
    # can still tag records from other in-flight requests.
    old_factory = logging.getLogRecordFactory()

    def record_factory(*args, **kwargs):
        record = old_factory(*args, **kwargs)
        record.correlation_id = correlation_id
        return record

    logging.setLogRecordFactory(record_factory)
    start = time.time()
    try:
        response = await call_next(request)
    finally:
        logging.setLogRecordFactory(old_factory)
    latency = time.time() - start

    # Provider label is set by the chat endpoint via request.state.
    REQUEST_COUNT.labels(
        method=request.method,
        endpoint=request.url.path,
        status=response.status_code,
        provider=getattr(request.state, "provider", "unknown"),
    ).inc()

    REQUEST_LATENCY.labels(
        method=request.method,
        endpoint=request.url.path,
        provider=getattr(request.state, "provider", "unknown"),
    ).observe(latency)

    response.headers["X-Correlation-ID"] = correlation_id
    return response
529
+
530
+ # ---------------------------------------------------------------------------
531
+ # Helper Functions
532
+ # ---------------------------------------------------------------------------
533
+
534
# Prefixes whose tag is stripped from the returned model name. Order is
# irrelevant because the tags are mutually exclusive prefixes.
_STRIPPED_PROVIDER_PREFIXES = (
    "nim", "ollama", "groq", "vllm", "llamacpp",
    "lmstudio", "mlx", "tgi", "local",
)

def get_provider_from_model(model: str) -> tuple[str, str]:
    """Extract provider and model name from a model string.

    ``anthropic/`` and ``openai/`` models keep the full prefixed ID;
    the providers in _STRIPPED_PROVIDER_PREFIXES get the tag removed;
    anything unprefixed falls back to Hugging Face.
    """
    if model.startswith(("anthropic/", "openai/")):
        return model.split("/", 1)[0], model

    for prefix in _STRIPPED_PROVIDER_PREFIXES:
        tag = f"{prefix}/"
        if model.startswith(tag):
            # removeprefix strips only the leading tag; the previous
            # str.replace(tag, "") also deleted later occurrences
            # (e.g. "nim/a/nim/b" -> "a/b").
            return prefix, model.removeprefix(tag)

    return "huggingface", model
560
+
561
def estimate_cost(provider: str, model: str, input_tokens: int, output_tokens: int) -> float:
    """Estimate the USD cost of one call from token counts.

    Rates are rough prices per 1M tokens; free-tier, local, and unknown
    providers all cost 0. `model` is part of the signature but does not
    affect the flat per-provider rate.
    """
    zero_rate = {"input": 0.0, "output": 0.0}
    # Only paid providers need explicit entries; everything else -- the
    # free tiers (groq, nim, huggingface) and all local backends (ollama,
    # llamacpp, lmstudio, vllm, mlx, tgi, local) -- falls through to zero.
    paid_rates = {
        "anthropic": {"input": 15.0, "output": 75.0},  # Claude Opus 4
        "openai": {"input": 2.5, "output": 10.0},  # GPT-4o
    }

    rates = paid_rates.get(provider, zero_rate)
    cost = (input_tokens / 1_000_000) * rates["input"] + (output_tokens / 1_000_000) * rates["output"]
    return cost
582
+
583
def generate_cache_key(request: ChatRequest) -> str:
    """Deterministic Redis cache key for a chat request.

    Keyed on model, messages, temperature, max_tokens, and tools.
    `stream` (plus session_id / api_key) is deliberately excluded so
    identical prompts share a single cache entry.
    """
    fingerprint = {
        "model": request.model,
        "messages": request.messages,
        "temperature": request.temperature,
        "max_tokens": request.max_tokens,
        "tools": request.tools,
    }
    canonical = json.dumps(fingerprint, sort_keys=True)
    digest = hashlib.sha256(canonical.encode()).hexdigest()
    return f"cache:llm:{digest}"
594
+
595
+ # ---------------------------------------------------------------------------
596
+ # API Endpoints
597
+ # ---------------------------------------------------------------------------
598
+
599
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint for load balancers and monitoring.

    Never raises: each dependency is probed defensively and any failure
    degrades the reported status instead of erroring the endpoint.
    """
    uptime = time.time() - start_time

    async def _probe(factory) -> bool:
        # Run a single dependency check; any exception means "down".
        try:
            await factory()
            return True
        except Exception:
            return False

    redis_ok = await _probe(lambda: redis_manager._redis.ping())
    db_ok = await _probe(lambda: db_pool.fetchval("SELECT 1"))

    # Report each provider's circuit-breaker state, tolerating Redis errors.
    circuits = {}
    for provider in ["anthropic", "openai", "groq", "nim", "huggingface", "ollama"]:
        try:
            provider_state = await redis_manager.get_circuit_state(provider)
            circuits[provider] = provider_state["state"]
        except Exception:
            circuits[provider] = "unknown"

    return HealthResponse(
        status="healthy" if redis_ok and db_ok else "degraded",
        uptime_seconds=uptime,
        active_sessions=0,  # Would query from DB
        redis_connected=redis_ok,
        db_connected=db_ok,
        circuit_breakers=circuits,
    )
635
+
636
@app.get("/metrics")
async def metrics():
    """Serve Prometheus metrics in the text exposition format."""
    # Local import: the plain starlette Response avoids FastAPI's JSON
    # serialization for this text/plain payload.
    from starlette.responses import Response

    payload = generate_latest()
    return Response(content=payload, media_type=CONTENT_TYPE_LATEST)
641
+
642
@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(
    request: ChatRequest,
    background_tasks: BackgroundTasks,
    http_request: Request,
):
    """OpenAI-compatible chat completions endpoint with production features.

    Pipeline: concurrency cap -> circuit breaker -> distributed rate
    limit -> cache lookup -> (mock) LLM call -> cache fill -> background
    DB persistence.

    `http_request` is the raw Starlette request, injected by FastAPI.
    The previous version read and wrote `.state` on the *Pydantic* body
    model, which raises on every call (and those errors were then counted
    as provider failures by the circuit breaker); the Starlette request's
    `.state` is the object the metrics middleware actually reads.
    """
    session_id = request.session_id or str(uuid.uuid4())
    provider, model_name = get_provider_from_model(request.model)
    # Expose the provider label to the metrics middleware via the shared
    # ASGI scope state.
    http_request.state.provider = provider

    logger.info(
        f"Chat request: provider={provider}, model={model_name}, "
        f"stream={request.stream}, session={session_id}"
    )

    # 1. Concurrency limit
    await concurrency_limiter.acquire()
    try:
        # 2. Circuit breaker check
        breaker = CircuitBreaker(redis_manager, provider)
        if not await breaker.can_execute():
            logger.warning(f"Circuit breaker OPEN for {provider}")
            raise HTTPException(
                status_code=503,
                detail=f"Service temporarily unavailable for provider {provider}. "
                       f"Circuit breaker is open. Try again later."
            )

        # 3. Rate limiting (per provider+session, shared via Redis)
        rpm = DEFAULT_RPM_LIMIT
        if provider == "nim":
            rpm = 40
        elif provider == "groq":
            rpm = 30

        rate_limit_key = f"{provider}:{session_id}"
        allowed, retry_after = await redis_manager.check_rate_limit(rate_limit_key, rpm)
        if not allowed:
            logger.warning(f"Rate limit exceeded for {rate_limit_key}")
            raise HTTPException(
                status_code=429,
                detail=f"Rate limit exceeded. Retry after {retry_after:.1f}s",
                headers={"Retry-After": str(int(retry_after))},
            )

        # 4. Check cache for non-streaming requests
        if not request.stream:
            cache_key = generate_cache_key(request)
            cached = await redis_manager.get_cache(cache_key)
            if cached:
                logger.info(f"Cache hit for {cache_key}")
                data = json.loads(cached)
                return ChatResponse(
                    id=str(uuid.uuid4()),
                    session_id=session_id,
                    model=request.model,
                    content=data.get("content"),
                    tool_calls=data.get("tool_calls"),
                    usage=data.get("usage", {}),
                    cost_usd=0.0,  # cache hits cost nothing
                    cached=True,
                    finish_reason=data.get("finish_reason"),
                )

        # 5. Budget check
        # TODO: Get session budget from DB
        cost_tracker = CostTracker(session_id, provider=provider, model=model_name)

        # 6. Call LLM (placeholder - would integrate with actual ml-intern agent)
        # For now, return a mock response with proper structure
        response_id = str(uuid.uuid4())

        # Simulate LLM call. `or ""` guards against an explicit null
        # content field (valid in OpenAI messages carrying tool calls),
        # which m.get("content", "") would pass through as None.
        input_tokens = sum(len((m.get("content") or "").split()) for m in request.messages) * 1.3
        output_tokens = request.max_tokens or 1000

        cost = estimate_cost(provider, model_name, int(input_tokens), output_tokens)
        cost_tracker.record_spend(cost)

        # Record success in circuit breaker
        await breaker.record_success()

        # Build response
        response = ChatResponse(
            id=response_id,
            session_id=session_id,
            model=request.model,
            content="This is a production-grade response from ml-intern.",
            usage={
                "prompt_tokens": int(input_tokens),
                "completion_tokens": output_tokens,
                "total_tokens": int(input_tokens) + output_tokens,
            },
            cost_usd=cost,
            cached=False,
            finish_reason="stop",
        )

        # 7. Cache response
        if not request.stream:
            cache_key = generate_cache_key(request)
            await redis_manager.set_cache(
                cache_key,
                json.dumps({
                    "content": response.content,
                    "tool_calls": response.tool_calls,
                    "usage": response.usage,
                    "finish_reason": response.finish_reason,
                }),
            )

        # 8. Persist to database (background)
        background_tasks.add_task(_persist_request, session_id, request, response)

        return response

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error processing request: {e}")
        # Record failure in circuit breaker. `breaker` is always bound
        # here: it is the first statement in the try block.
        await breaker.record_failure()
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        concurrency_limiter.release()
767
+
768
async def _persist_request(session_id: str, request: ChatRequest, response: ChatResponse):
    """Persist request/response to database (background task).

    Best-effort: failures are logged and swallowed so a database outage
    never affects the already-sent client response.
    """
    try:
        await db_pool.execute(
            """
            INSERT INTO requests (id, session_id, model, provider, input_tokens,
                                  output_tokens, cost_usd, latency_ms, cached)
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
            """,
            response.id,
            session_id,
            request.model,
            get_provider_from_model(request.model)[0],
            response.usage.get("prompt_tokens", 0),
            response.usage.get("completion_tokens", 0),
            response.cost_usd,
            0,  # latency would be measured
            response.cached,
        )
    except Exception as e:
        logger.error(f"Failed to persist request: {e}")
789
+
790
@app.get("/v1/models")
async def list_models():
    """List available models (OpenAI-compatible /v1/models shape)."""
    # Static (model_id, owner) catalog; expanded into the OpenAI schema below.
    catalog = [
        ("anthropic/claude-opus-4-6", "anthropic"),
        ("anthropic/claude-opus-4-7", "anthropic"),
        ("openai/gpt-5.5", "openai"),
        ("openai/gpt-5.4", "openai"),
        ("groq/llama-3.3-70b-versatile", "groq"),
        ("groq/llama-3.1-8b-instant", "groq"),
        ("nim/llama-3-8b", "nvidia"),
        ("nim/llama-3.1-405b-instruct", "nvidia"),
        ("ollama/llama3.1", "ollama"),
        ("vllm/llama-3-8b", "vllm"),
        ("llamacpp/llama-3-8b", "llamacpp"),
        ("lmstudio/llama-3-8b", "lmstudio"),
        ("mlx/llama-3-8b", "mlx"),
        ("tgi/llama-3-8b", "tgi"),
        ("local/llama-3-8b", "local"),
    ]
    return {
        "object": "list",
        "data": [
            {"id": model_id, "object": "model", "owned_by": owner}
            for model_id, owner in catalog
        ],
    }
813
+
814
@app.delete("/v1/sessions/{session_id}")
async def delete_session(session_id: str):
    """Soft-delete a session.

    Flags metadata.deleted instead of dropping rows so request history
    stays available for auditing. Rate-limit buckets for the session
    (`ratelimit:{provider}:{session_id}`) are left to expire via their
    120 s Redis TTL; purging them eagerly would require SCAN (never KEYS
    in production). The previous version built an unused key pattern for
    that purge -- removed as dead code.
    """
    await db_pool.execute(
        "UPDATE sessions SET metadata = jsonb_set(metadata, '{deleted}', 'true') WHERE id = $1",
        session_id,
    )

    return {"deleted": True, "session_id": session_id}
827
+
828
+ # ---------------------------------------------------------------------------
829
+ # Main Entry Point
830
+ # ---------------------------------------------------------------------------
831
+
832
if __name__ == "__main__":
    # PORT/WORKERS come from the environment (container-friendly defaults).
    port = int(os.environ.get("PORT", "8000"))
    workers = int(os.environ.get("WORKERS", "1"))

    # The app is given as an import string so uvicorn can spawn worker
    # processes; this requires running from the directory containing
    # production_server.py.
    # NOTE(review): with WORKERS > 1 each process builds its own
    # ConcurrencyLimiter semaphore, so the effective cap becomes
    # workers * MAX_CONCURRENT_REQUESTS, while Redis-backed state (rate
    # limits, circuit breakers) stays shared -- confirm that is intended.
    uvicorn.run(
        "production_server:app",
        host="0.0.0.0",  # bind all interfaces (expected behind a proxy/LB)
        port=port,
        workers=workers,
        log_level="info",
        access_log=True,
    )
+ )