raazkumar committed on
Commit
8296605
·
verified ·
1 Parent(s): 03cc10d

Upload production/production_server.py

Browse files
Files changed (1) hide show
  1. production/production_server.py +337 -248
production/production_server.py CHANGED
@@ -5,12 +5,14 @@ Features:
5
  - FastAPI with async endpoints
6
  - Distributed rate limiting (Redis-backed token bucket)
7
  - Circuit breaker for external API resilience
 
8
  - Request/response caching with Redis TTL
9
  - Multi-tenant session isolation
10
  - Health checks and graceful shutdown
11
  - Structured logging with correlation IDs
12
  - Cost tracking and budget enforcement
13
  - Connection pooling for all HTTP clients
 
14
  """
15
 
16
  import asyncio
@@ -26,22 +28,23 @@ from contextlib import asynccontextmanager
26
  from dataclasses import dataclass, field
27
  from typing import Any, Optional
28
 
29
- import aioredis
30
  import asyncpg
31
  from fastapi import FastAPI, HTTPException, Request, Depends, BackgroundTasks
32
  from fastapi.middleware.cors import CORSMiddleware
33
  from fastapi.middleware.gzip import GZipMiddleware
34
- from fastapi.responses import JSONResponse, StreamingResponse
35
  from pydantic import BaseModel, Field
36
  import uvicorn
37
  from prometheus_client import Counter, Histogram, Gauge, generate_latest, CONTENT_TYPE_LATEST
 
 
38
 
39
  # ---------------------------------------------------------------------------
40
  # Configuration
41
  # ---------------------------------------------------------------------------
42
 
43
  REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379")
44
- DATABASE_URL = os.environ.get("DATABASE_URL", "postgresql://localhost/ml_intern")
45
  MAX_CONCURRENT_REQUESTS = int(os.environ.get("MAX_CONCURRENT_REQUESTS", "100"))
46
  DEFAULT_RPM_LIMIT = int(os.environ.get("DEFAULT_RPM_LIMIT", "40"))
47
  REQUEST_TIMEOUT = float(os.environ.get("REQUEST_TIMEOUT", "120"))
@@ -50,6 +53,20 @@ BUDGET_USD_PER_SESSION = float(os.environ.get("BUDGET_USD_PER_SESSION", "10.0"))
50
  CIRCUIT_BREAKER_FAILURE_THRESHOLD = int(os.environ.get("CIRCUIT_BREAKER_FAILURE_THRESHOLD", "5"))
51
  CIRCUIT_BREAKER_RECOVERY_TIMEOUT = int(os.environ.get("CIRCUIT_BREAKER_RECOVERY_TIMEOUT", "60"))
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  # ---------------------------------------------------------------------------
54
  # Prometheus Metrics
55
  # ---------------------------------------------------------------------------
@@ -89,22 +106,28 @@ CIRCUIT_BREAKER_STATE = Gauge(
89
  "Circuit breaker state (0=closed, 1=half-open, 2=open)",
90
  ["provider"],
91
  )
 
 
 
 
 
92
 
93
  # ---------------------------------------------------------------------------
94
  # Structured Logging
95
  # ---------------------------------------------------------------------------
96
 
97
- class CorrelationIdFilter(logging.Filter):
98
- def filter(self, record: logging.LogRecord) -> bool:
99
- record.correlation_id = getattr(record, "correlation_id", "none")
100
- return True
101
-
102
  logging.basicConfig(
103
  level=logging.INFO,
104
  format="%(asctime)s | %(levelname)s | correlation_id=%(correlation_id)s | %(name)s | %(message)s",
105
  handlers=[logging.StreamHandler(sys.stdout)],
106
  )
107
  logger = logging.getLogger("ml_intern.production")
 
 
 
 
 
 
108
  logger.addFilter(CorrelationIdFilter())
109
 
110
  # ---------------------------------------------------------------------------
@@ -112,17 +135,18 @@ logger.addFilter(CorrelationIdFilter())
112
  # ---------------------------------------------------------------------------
113
 
114
  class DatabasePool:
115
- """Async PostgreSQL connection pool with prepared statements."""
116
-
117
  def __init__(self, dsn: str):
118
  self.dsn = dsn
119
  self._pool: Optional[asyncpg.Pool] = None
120
 
121
  async def connect(self):
 
 
 
122
  self._pool = await asyncpg.create_pool(
123
  self.dsn,
124
- min_size=5,
125
- max_size=20,
126
  command_timeout=60,
127
  )
128
  logger.info("Database pool connected")
@@ -133,28 +157,28 @@ class DatabasePool:
133
  logger.info("Database pool disconnected")
134
 
135
  async def execute(self, query: str, *args):
 
 
136
  async with self._pool.acquire() as conn:
137
  return await conn.execute(query, *args)
138
 
139
  async def fetch(self, query: str, *args):
 
 
140
  async with self._pool.acquire() as conn:
141
  return await conn.fetch(query, *args)
142
 
143
- async def fetchrow(self, query: str, *args):
144
- async with self._pool.acquire() as conn:
145
- return await conn.fetchrow(query, *args)
146
-
147
  async def fetchval(self, query: str, *args):
 
 
148
  async with self._pool.acquire() as conn:
149
  return await conn.fetchval(query, *args)
150
 
151
  # ---------------------------------------------------------------------------
152
- # Redis Layer (Caching + Rate Limiting + Distributed State)
153
  # ---------------------------------------------------------------------------
154
 
155
  class RedisManager:
156
- """Redis client for caching, rate limiting, and distributed state."""
157
-
158
  def __init__(self, url: str):
159
  self.url = url
160
  self._redis: Optional[aioredis.Redis] = None
@@ -169,8 +193,6 @@ class RedisManager:
169
  await self._redis.close()
170
  logger.info("Redis disconnected")
171
 
172
- # --- Caching ---
173
-
174
  async def get_cache(self, key: str) -> Optional[str]:
175
  val = await self._redis.get(key)
176
  if val:
@@ -182,26 +204,16 @@ class RedisManager:
182
  async def set_cache(self, key: str, value: str, ttl: int = CACHE_TTL_SECONDS):
183
  await self._redis.setex(key, ttl, value)
184
 
185
- async def delete_cache(self, key: str):
186
- await self._redis.delete(key)
187
-
188
- # --- Rate Limiting (Token Bucket) ---
189
-
190
  async def check_rate_limit(self, key: str, rpm: int) -> tuple[bool, float]:
191
- """Check if request is within rate limit. Returns (allowed, retry_after)."""
192
  now = time.time()
193
  bucket_key = f"ratelimit:{key}"
194
-
195
- # Lua script for atomic token bucket
196
  script = """
197
  local key = KEYS[1]
198
  local now = tonumber(ARGV[1])
199
  local rpm = tonumber(ARGV[2])
200
  local interval = 60.0 / rpm
201
-
202
  local last = redis.call('hget', key, 'last')
203
  local tokens = redis.call('hget', key, 'tokens')
204
-
205
  if not last then
206
  last = 0
207
  tokens = 1
@@ -209,10 +221,8 @@ class RedisManager:
209
  last = tonumber(last)
210
  tokens = tonumber(tokens)
211
  end
212
-
213
  local elapsed = now - last
214
  tokens = math.min(1, tokens + elapsed / interval)
215
-
216
  if tokens >= 1 then
217
  tokens = tokens - 1
218
  redis.call('hmset', key, 'last', now, 'tokens', tokens)
@@ -225,14 +235,11 @@ class RedisManager:
225
  return {0, retry_after}
226
  end
227
  """
228
-
229
  result = await self._redis.eval(script, 1, bucket_key, now, rpm)
230
  allowed = bool(result[0])
231
  retry_after = float(result[1]) if not allowed else 0.0
232
  return allowed, retry_after
233
 
234
- # --- Circuit Breaker State ---
235
-
236
  async def get_circuit_state(self, provider: str) -> dict:
237
  key = f"circuit:{provider}"
238
  val = await self._redis.get(key)
@@ -249,8 +256,6 @@ class RedisManager:
249
  # ---------------------------------------------------------------------------
250
 
251
  class CircuitBreaker:
252
- """Distributed circuit breaker using Redis."""
253
-
254
  def __init__(self, redis: RedisManager, provider: str):
255
  self.redis = redis
256
  self.provider = provider
@@ -259,21 +264,17 @@ class CircuitBreaker:
259
 
260
  async def can_execute(self) -> bool:
261
  state = await self.redis.get_circuit_state(self.provider)
262
-
263
  if state["state"] == "open":
264
  if time.time() - state["last_failure"] > self.recovery_timeout:
265
  state["state"] = "half-open"
266
  state["failures"] = 0
267
  await self.redis.set_circuit_state(self.provider, state)
268
  CIRCUIT_BREAKER_STATE.labels(provider=self.provider).set(1)
269
- logger.info(f"Circuit breaker for {self.provider} entering half-open state")
270
  return True
271
  CIRCUIT_BREAKER_STATE.labels(provider=self.provider).set(2)
272
  return False
273
-
274
- CIRCUIT_BREAKER_STATE.labels(provider=self.provider).set(
275
- 0 if state["state"] == "closed" else 1
276
- )
277
  return True
278
 
279
  async def record_success(self):
@@ -283,22 +284,101 @@ class CircuitBreaker:
283
  state["failures"] = 0
284
  await self.redis.set_circuit_state(self.provider, state)
285
  CIRCUIT_BREAKER_STATE.labels(provider=self.provider).set(0)
286
- logger.info(f"Circuit breaker for {self.provider} closed after recovery")
287
 
288
  async def record_failure(self):
289
  state = await self.redis.get_circuit_state(self.provider)
290
  state["failures"] += 1
291
  state["last_failure"] = time.time()
292
-
293
  if state["failures"] >= self.failure_threshold:
294
  state["state"] = "open"
295
  CIRCUIT_BREAKER_STATE.labels(provider=self.provider).set(2)
296
- logger.warning(
297
- f"Circuit breaker for {self.provider} OPENED after "
298
- f"{state['failures']} failures"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  )
 
 
 
 
 
 
 
 
300
 
301
- await self.redis.set_circuit_state(self.provider, state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
 
303
  # ---------------------------------------------------------------------------
304
  # Cost Tracking
@@ -306,8 +386,6 @@ class CircuitBreaker:
306
 
307
  @dataclass
308
  class CostTracker:
309
- """Per-session cost tracking with budget enforcement."""
310
-
311
  session_id: str
312
  budget_usd: float = BUDGET_USD_PER_SESSION
313
  spent_usd: float = 0.0
@@ -320,18 +398,13 @@ class CostTracker:
320
  def record_spend(self, cost_usd: float):
321
  self.spent_usd += cost_usd
322
  LLM_COST_USD.labels(provider=self.provider, model=self.model).inc(cost_usd)
323
- logger.info(
324
- f"Session {self.session_id}: spent ${cost_usd:.4f}, "
325
- f"total ${self.spent_usd:.4f} / ${self.budget_usd:.2f}"
326
- )
327
 
328
  # ---------------------------------------------------------------------------
329
- # Semaphore for Concurrency Control
330
  # ---------------------------------------------------------------------------
331
 
332
  class ConcurrencyLimiter:
333
- """Global concurrent request limiter."""
334
-
335
  def __init__(self, max_concurrent: int):
336
  self.semaphore = asyncio.Semaphore(max_concurrent)
337
 
@@ -346,7 +419,7 @@ class ConcurrencyLimiter:
346
  # ---------------------------------------------------------------------------
347
 
348
  class ChatRequest(BaseModel):
349
- model: str = Field(..., description="Model ID (e.g., nim/llama-3-8b)")
350
  messages: list[dict] = Field(..., description="OpenAI-compatible messages")
351
  temperature: Optional[float] = 0.7
352
  max_tokens: Optional[int] = 4096
@@ -354,18 +427,20 @@ class ChatRequest(BaseModel):
354
  tools: Optional[list[dict]] = None
355
  tool_choice: Optional[str] = "auto"
356
  session_id: Optional[str] = None
357
- api_key: Optional[str] = None # Provider-specific API key override
358
 
359
  class ChatResponse(BaseModel):
360
  id: str
361
  session_id: str
362
  model: str
 
363
  content: Optional[str] = None
364
  tool_calls: Optional[list[dict]] = None
365
  usage: dict = Field(default_factory=dict)
366
  cost_usd: float = 0.0
367
  cached: bool = False
368
  finish_reason: Optional[str] = None
 
369
 
370
  class HealthResponse(BaseModel):
371
  status: str
@@ -375,17 +450,16 @@ class HealthResponse(BaseModel):
375
  redis_connected: bool
376
  db_connected: bool
377
  circuit_breakers: dict[str, str]
378
-
379
- class MetricsResponse(BaseModel):
380
- prometheus: str
381
 
382
  # ---------------------------------------------------------------------------
383
- # Global State (set during lifespan)
384
  # ---------------------------------------------------------------------------
385
 
386
  db_pool: Optional[DatabasePool] = None
387
  redis_manager: Optional[RedisManager] = None
388
  concurrency_limiter: Optional[ConcurrencyLimiter] = None
 
389
  start_time: float = 0.0
390
  shutdown_event: asyncio.Event = asyncio.Event()
391
 
@@ -395,12 +469,10 @@ shutdown_event: asyncio.Event = asyncio.Event()
395
 
396
  @asynccontextmanager
397
  async def lifespan(app: FastAPI):
398
- """Application lifespan manager."""
399
- global db_pool, redis_manager, concurrency_limiter, start_time
400
 
401
  start_time = time.time()
402
 
403
- # Initialize connections
404
  db_pool = DatabasePool(DATABASE_URL)
405
  await db_pool.connect()
406
 
@@ -408,23 +480,25 @@ async def lifespan(app: FastAPI):
408
  await redis_manager.connect()
409
 
410
  concurrency_limiter = ConcurrencyLimiter(MAX_CONCURRENT_REQUESTS)
 
 
411
 
412
- # Graceful shutdown handler
413
  loop = asyncio.get_event_loop()
414
  for sig in (signal.SIGTERM, signal.SIGINT):
415
  loop.add_signal_handler(sig, lambda: asyncio.create_task(_shutdown()))
416
 
417
- # Initialize database schema
418
- await _init_schema()
419
 
420
  logger.info("ml-intern production server started")
421
 
422
  yield
423
 
424
- # Shutdown
425
  logger.info("Shutting down...")
426
  shutdown_event.set()
427
 
 
 
428
  if redis_manager:
429
  await redis_manager.disconnect()
430
  if db_pool:
@@ -437,7 +511,6 @@ async def _shutdown():
437
  shutdown_event.set()
438
 
439
  async def _init_schema():
440
- """Initialize database schema if not exists."""
441
  await db_pool.execute("""
442
  CREATE TABLE IF NOT EXISTS sessions (
443
  id TEXT PRIMARY KEY,
@@ -451,7 +524,7 @@ async def _init_schema():
451
  await db_pool.execute("""
452
  CREATE TABLE IF NOT EXISTS requests (
453
  id TEXT PRIMARY KEY,
454
- session_id TEXT REFERENCES sessions(id),
455
  model TEXT,
456
  provider TEXT,
457
  input_tokens INTEGER,
@@ -459,15 +532,7 @@ async def _init_schema():
459
  cost_usd NUMERIC,
460
  latency_ms INTEGER,
461
  cached BOOLEAN DEFAULT FALSE,
462
- created_at TIMESTAMP DEFAULT NOW()
463
- )
464
- """)
465
- await db_pool.execute("""
466
- CREATE TABLE IF NOT EXISTS circuit_events (
467
- id SERIAL PRIMARY KEY,
468
- provider TEXT,
469
- event_type TEXT,
470
- details JSONB,
471
  created_at TIMESTAMP DEFAULT NOW()
472
  )
473
  """)
@@ -475,7 +540,7 @@ async def _init_schema():
475
 
476
  app = FastAPI(
477
  title="ml-intern Production API",
478
- description="Production-grade API for ml-intern with rate limiting, caching, and multi-tenancy",
479
  version="1.0.0",
480
  lifespan=lifespan,
481
  )
@@ -495,11 +560,9 @@ app.add_middleware(
495
 
496
  @app.middleware("http")
497
  async def correlation_id_middleware(request: Request, call_next):
498
- """Add correlation ID to all requests."""
499
  correlation_id = request.headers.get("X-Correlation-ID", str(uuid.uuid4()))
500
  request.state.correlation_id = correlation_id
501
 
502
- # Set correlation ID in logger adapter
503
  old_factory = logging.getLogRecordFactory()
504
  def record_factory(*args, **kwargs):
505
  record = old_factory(*args, **kwargs)
@@ -531,58 +594,12 @@ async def correlation_id_middleware(request: Request, call_next):
531
  # Helper Functions
532
  # ---------------------------------------------------------------------------
533
 
534
- def get_provider_from_model(model: str) -> tuple[str, str]:
535
- """Extract provider and model name from model string."""
536
- if model.startswith("anthropic/"):
537
- return "anthropic", model
538
- elif model.startswith("openai/"):
539
- return "openai", model
540
- elif model.startswith("nim/"):
541
- return "nim", model.replace("nim/", "")
542
- elif model.startswith("ollama/"):
543
- return "ollama", model.replace("ollama/", "")
544
- elif model.startswith("groq/"):
545
- return "groq", model.replace("groq/", "")
546
- elif model.startswith("vllm/"):
547
- return "vllm", model.replace("vllm/", "")
548
- elif model.startswith("llamacpp/"):
549
- return "llamacpp", model.replace("llamacpp/", "")
550
- elif model.startswith("lmstudio/"):
551
- return "lmstudio", model.replace("lmstudio/", "")
552
- elif model.startswith("mlx/"):
553
- return "mlx", model.replace("mlx/", "")
554
- elif model.startswith("tgi/"):
555
- return "tgi", model.replace("tgi/", "")
556
- elif model.startswith("local/"):
557
- return "local", model.replace("local/", "")
558
- else:
559
- return "huggingface", model
560
-
561
- def estimate_cost(provider: str, model: str, input_tokens: int, output_tokens: int) -> float:
562
- """Estimate cost in USD based on provider pricing."""
563
- # Pricing per 1M tokens (rough estimates)
564
- pricing = {
565
- "anthropic": {"input": 15.0, "output": 75.0}, # Claude Opus 4
566
- "openai": {"input": 2.5, "output": 10.0}, # GPT-4o
567
- "groq": {"input": 0.0, "output": 0.0}, # Free tier
568
- "nim": {"input": 0.0, "output": 0.0}, # Free tier
569
- "huggingface": {"input": 0.0, "output": 0.0}, # Free credits
570
- "ollama": {"input": 0.0, "output": 0.0}, # Local
571
- "llamacpp": {"input": 0.0, "output": 0.0}, # Local
572
- "lmstudio": {"input": 0.0, "output": 0.0}, # Local
573
- "vllm": {"input": 0.0, "output": 0.0}, # Local
574
- "mlx": {"input": 0.0, "output": 0.0}, # Local
575
- "tgi": {"input": 0.0, "output": 0.0}, # Local
576
- "local": {"input": 0.0, "output": 0.0}, # Local
577
- }
578
-
579
- p = pricing.get(provider, {"input": 0.0, "output": 0.0})
580
- cost = (input_tokens / 1_000_000) * p["input"] + (output_tokens / 1_000_000) * p["output"]
581
  return cost
582
 
583
  def generate_cache_key(request: ChatRequest) -> str:
584
- """Generate deterministic cache key from request."""
585
- # Hash of messages + model + temperature (exclude stream)
586
  content = json.dumps({
587
  "model": request.model,
588
  "messages": request.messages,
@@ -592,13 +609,93 @@ def generate_cache_key(request: ChatRequest) -> str:
592
  }, sort_keys=True)
593
  return f"cache:llm:{hashlib.sha256(content.encode()).hexdigest()}"
594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
595
  # ---------------------------------------------------------------------------
596
  # API Endpoints
597
  # ---------------------------------------------------------------------------
598
 
599
  @app.get("/health", response_model=HealthResponse)
600
  async def health_check():
601
- """Health check endpoint for load balancers and monitoring."""
602
  uptime = time.time() - start_time
603
 
604
  redis_ok = False
@@ -609,81 +706,70 @@ async def health_check():
609
  pass
610
 
611
  db_ok = False
612
- try:
613
- await db_pool.fetchval("SELECT 1")
614
- db_ok = True
615
- except Exception:
616
- pass
 
617
 
618
- # Get circuit breaker states
619
  circuits = {}
620
- for provider in ["anthropic", "openai", "groq", "nim", "huggingface", "ollama"]:
 
621
  try:
622
  state = await redis_manager.get_circuit_state(provider)
623
  circuits[provider] = state["state"]
 
624
  except Exception:
625
  circuits[provider] = "unknown"
 
626
 
627
  return HealthResponse(
628
- status="healthy" if redis_ok and db_ok else "degraded",
 
629
  uptime_seconds=uptime,
630
- active_sessions=0, # Would query from DB
631
  redis_connected=redis_ok,
632
  db_connected=db_ok,
633
  circuit_breakers=circuits,
 
634
  )
635
 
636
  @app.get("/metrics")
637
  async def metrics():
638
- """Prometheus metrics endpoint."""
639
  from starlette.responses import Response
640
  return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST)
641
 
642
  @app.post("/v1/chat/completions", response_model=ChatResponse)
643
  async def chat_completions(request: ChatRequest, background_tasks: BackgroundTasks):
644
- """OpenAI-compatible chat completions endpoint with production features."""
645
-
646
  correlation_id = getattr(request.state, "correlation_id", str(uuid.uuid4()))
647
  session_id = request.session_id or str(uuid.uuid4())
648
- provider, model_name = get_provider_from_model(request.model)
649
- request.state.provider = provider
650
 
651
- logger.info(
652
- f"Chat request: provider={provider}, model={model_name}, "
653
- f"stream={request.stream}, session={session_id}"
654
- )
655
 
656
- # 1. Concurrency limit
657
  await concurrency_limiter.acquire()
658
  try:
659
- # 2. Circuit breaker check
660
- breaker = CircuitBreaker(redis_manager, provider)
661
- if not await breaker.can_execute():
662
- logger.warning(f"Circuit breaker OPEN for {provider}")
663
- raise HTTPException(
664
- status_code=503,
665
- detail=f"Service temporarily unavailable for provider {provider}. "
666
- f"Circuit breaker is open. Try again later."
667
- )
668
 
669
- # 3. Rate limiting
670
- rpm = DEFAULT_RPM_LIMIT
671
- if provider == "nim":
672
- rpm = 40
673
- elif provider == "groq":
674
- rpm = 30
675
 
 
 
676
  rate_limit_key = f"{provider}:{session_id}"
677
  allowed, retry_after = await redis_manager.check_rate_limit(rate_limit_key, rpm)
678
  if not allowed:
679
  logger.warning(f"Rate limit exceeded for {rate_limit_key}")
680
- raise HTTPException(
681
- status_code=429,
682
- detail=f"Rate limit exceeded. Retry after {retry_after:.1f}s",
683
- headers={"Retry-After": str(int(retry_after))},
684
- )
685
 
686
- # 4. Check cache for non-streaming requests
687
  if not request.stream:
688
  cache_key = generate_cache_key(request)
689
  cached = await redis_manager.get_cache(cache_key)
@@ -694,49 +780,29 @@ async def chat_completions(request: ChatRequest, background_tasks: BackgroundTas
694
  id=str(uuid.uuid4()),
695
  session_id=session_id,
696
  model=request.model,
 
697
  content=data.get("content"),
698
  tool_calls=data.get("tool_calls"),
699
  usage=data.get("usage", {}),
700
  cost_usd=0.0,
701
  cached=True,
702
  finish_reason=data.get("finish_reason"),
 
703
  )
704
 
705
- # 5. Budget check
706
- # TODO: Get session budget from DB
707
- cost_tracker = CostTracker(session_id, provider=provider, model=model_name)
708
-
709
- # 6. Call LLM (placeholder - would integrate with actual ml-intern agent)
710
- # For now, return a mock response with proper structure
711
- response_id = str(uuid.uuid4())
712
-
713
- # Simulate LLM call
714
- input_tokens = sum(len(m.get("content", "").split()) for m in request.messages) * 1.3
715
- output_tokens = request.max_tokens or 1000
716
 
717
- cost = estimate_cost(provider, model_name, int(input_tokens), output_tokens)
718
- cost_tracker.record_spend(cost)
719
 
720
- # Record success in circuit breaker
 
721
  await breaker.record_success()
722
 
723
- # Build response
724
- response = ChatResponse(
725
- id=response_id,
726
- session_id=session_id,
727
- model=request.model,
728
- content="This is a production-grade response from ml-intern.",
729
- usage={
730
- "prompt_tokens": int(input_tokens),
731
- "completion_tokens": output_tokens,
732
- "total_tokens": int(input_tokens) + output_tokens,
733
- },
734
- cost_usd=cost,
735
- cached=False,
736
- finish_reason="stop",
737
- )
738
 
739
- # 7. Cache response
740
  if not request.stream:
741
  cache_key = generate_cache_key(request)
742
  await redis_manager.set_cache(
@@ -749,8 +815,13 @@ async def chat_completions(request: ChatRequest, background_tasks: BackgroundTas
749
  }),
750
  )
751
 
752
- # 8. Persist to database (background)
753
- background_tasks.add_task(_persist_request, session_id, request, response)
 
 
 
 
 
754
 
755
  return response
756
 
@@ -758,72 +829,90 @@ async def chat_completions(request: ChatRequest, background_tasks: BackgroundTas
758
  raise
759
  except Exception as e:
760
  logger.exception(f"Error processing request: {e}")
761
- # Record failure in circuit breaker
762
- breaker = CircuitBreaker(redis_manager, provider)
763
  await breaker.record_failure()
764
  raise HTTPException(status_code=500, detail=str(e))
765
  finally:
766
  concurrency_limiter.release()
767
 
768
- async def _persist_request(session_id: str, request: ChatRequest, response: ChatResponse):
769
- """Persist request/response to database (background task)."""
770
  try:
771
  await db_pool.execute(
772
  """
773
  INSERT INTO requests (id, session_id, model, provider, input_tokens,
774
- output_tokens, cost_usd, latency_ms, cached)
775
- VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
776
  """,
777
  response.id,
778
  session_id,
779
  request.model,
780
- get_provider_from_model(request.model)[0],
781
  response.usage.get("prompt_tokens", 0),
782
  response.usage.get("completion_tokens", 0),
783
  response.cost_usd,
784
- 0, # latency would be measured
785
  response.cached,
 
786
  )
787
  except Exception as e:
788
  logger.error(f"Failed to persist request: {e}")
789
 
790
  @app.get("/v1/models")
791
  async def list_models():
792
- """List available models."""
793
- return {
794
- "object": "list",
795
- "data": [
796
- {"id": "anthropic/claude-opus-4-6", "object": "model", "owned_by": "anthropic"},
797
- {"id": "anthropic/claude-opus-4-7", "object": "model", "owned_by": "anthropic"},
798
- {"id": "openai/gpt-5.5", "object": "model", "owned_by": "openai"},
799
- {"id": "openai/gpt-5.4", "object": "model", "owned_by": "openai"},
800
- {"id": "groq/llama-3.3-70b-versatile", "object": "model", "owned_by": "groq"},
801
- {"id": "groq/llama-3.1-8b-instant", "object": "model", "owned_by": "groq"},
802
- {"id": "nim/llama-3-8b", "object": "model", "owned_by": "nvidia"},
803
  {"id": "nim/llama-3.1-405b-instruct", "object": "model", "owned_by": "nvidia"},
804
- {"id": "ollama/llama3.1", "object": "model", "owned_by": "ollama"},
805
- {"id": "vllm/llama-3-8b", "object": "model", "owned_by": "vllm"},
806
- {"id": "llamacpp/llama-3-8b", "object": "model", "owned_by": "llamacpp"},
807
- {"id": "lmstudio/llama-3-8b", "object": "model", "owned_by": "lmstudio"},
808
- {"id": "mlx/llama-3-8b", "object": "model", "owned_by": "mlx"},
809
- {"id": "tgi/llama-3-8b", "object": "model", "owned_by": "tgi"},
810
- {"id": "local/llama-3-8b", "object": "model", "owned_by": "local"},
811
- ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
812
  }
813
 
814
- @app.delete("/v1/sessions/{session_id}")
815
- async def delete_session(session_id: str):
816
- """Delete a session and all its data."""
817
- # Clear cache entries for this session
818
- pattern = f"ratelimit:*:{session_id}"
819
- # Note: In production, use SCAN instead of KEYS
820
-
821
- await db_pool.execute(
822
- "UPDATE sessions SET metadata = jsonb_set(metadata, '{deleted}', 'true') WHERE id = $1",
823
- session_id,
824
- )
825
-
826
- return {"deleted": True, "session_id": session_id}
827
 
828
  # ---------------------------------------------------------------------------
829
  # Main Entry Point
 
5
  - FastAPI with async endpoints
6
  - Distributed rate limiting (Redis-backed token bucket)
7
  - Circuit breaker for external API resilience
8
+ - Automatic fallback: NIM (primary) -> Cloudflare Workers AI (fallback)
9
  - Request/response caching with Redis TTL
10
  - Multi-tenant session isolation
11
  - Health checks and graceful shutdown
12
  - Structured logging with correlation IDs
13
  - Cost tracking and budget enforcement
14
  - Connection pooling for all HTTP clients
15
+ - Cloudflare Workers AI support via OpenAI-compatible API
16
  """
17
 
18
  import asyncio
 
28
  from dataclasses import dataclass, field
29
  from typing import Any, Optional
30
 
31
+ import redis.asyncio as aioredis
32
  import asyncpg
33
  from fastapi import FastAPI, HTTPException, Request, Depends, BackgroundTasks
34
  from fastapi.middleware.cors import CORSMiddleware
35
  from fastapi.middleware.gzip import GZipMiddleware
 
36
  from pydantic import BaseModel, Field
37
  import uvicorn
38
  from prometheus_client import Counter, Histogram, Gauge, generate_latest, CONTENT_TYPE_LATEST
39
+ import httpx
40
+ from tenacity import retry, stop_after_attempt, wait_exponential, RetryError
41
 
42
  # ---------------------------------------------------------------------------
43
  # Configuration
44
  # ---------------------------------------------------------------------------
45
 
46
  REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379")
47
+ DATABASE_URL = os.environ.get("DATABASE_URL", "")
48
  MAX_CONCURRENT_REQUESTS = int(os.environ.get("MAX_CONCURRENT_REQUESTS", "100"))
49
  DEFAULT_RPM_LIMIT = int(os.environ.get("DEFAULT_RPM_LIMIT", "40"))
50
  REQUEST_TIMEOUT = float(os.environ.get("REQUEST_TIMEOUT", "120"))
 
53
  CIRCUIT_BREAKER_FAILURE_THRESHOLD = int(os.environ.get("CIRCUIT_BREAKER_FAILURE_THRESHOLD", "5"))
54
  CIRCUIT_BREAKER_RECOVERY_TIMEOUT = int(os.environ.get("CIRCUIT_BREAKER_RECOVERY_TIMEOUT", "60"))
55
 
56
+ # Provider-specific endpoints
57
+ NIM_API_BASE = os.environ.get("NIM_API_BASE", "https://integrate.api.nvidia.com/v1")
58
+ CLOUDFLARE_API_KEY = os.environ.get("CLOUDFLARE_API_KEY", "")
59
+ CLOUDFLARE_ACCOUNT_ID = os.environ.get("CLOUDFLARE_ACCOUNT_ID", "")
60
+
61
+ # Fallback configuration
62
+ FALLBACK_ENABLED = os.environ.get("FALLBACK_ENABLED", "true").lower() == "true"
63
+ FALLBACK_PRIMARY = os.environ.get("FALLBACK_PRIMARY", "nim")
64
+ FALLBACK_SECONDARY = os.environ.get("FALLBACK_SECONDARY", "cloudflare")
65
+
66
+ # MLX (local Apple Silicon)
67
+ MLX_API_BASE = os.environ.get("MLX_API_BASE", "http://localhost:8000/v1")
68
+ MLX_ENABLED = os.environ.get("MLX_ENABLED", "false").lower() == "true"
69
+
70
  # ---------------------------------------------------------------------------
71
  # Prometheus Metrics
72
  # ---------------------------------------------------------------------------
 
106
  "Circuit breaker state (0=closed, 1=half-open, 2=open)",
107
  ["provider"],
108
  )
109
+ FALLBACK_COUNT = Counter(
110
+ "ml_intern_fallback_total",
111
+ "Fallback events between providers",
112
+ ["from_provider", "to_provider", "reason"],
113
+ )
114
 
115
  # ---------------------------------------------------------------------------
116
  # Structured Logging
117
  # ---------------------------------------------------------------------------
118
 
 
 
 
 
 
119
  logging.basicConfig(
120
  level=logging.INFO,
121
  format="%(asctime)s | %(levelname)s | correlation_id=%(correlation_id)s | %(name)s | %(message)s",
122
  handlers=[logging.StreamHandler(sys.stdout)],
123
  )
124
  logger = logging.getLogger("ml_intern.production")
125
+
126
+ class CorrelationIdFilter(logging.Filter):
127
+ def filter(self, record: logging.LogRecord) -> bool:
128
+ record.correlation_id = getattr(record, "correlation_id", "none")
129
+ return True
130
+
131
  logger.addFilter(CorrelationIdFilter())
132
 
133
  # ---------------------------------------------------------------------------
 
135
  # ---------------------------------------------------------------------------
136
 
137
  class DatabasePool:
 
 
138
  def __init__(self, dsn: str):
139
  self.dsn = dsn
140
  self._pool: Optional[asyncpg.Pool] = None
141
 
142
  async def connect(self):
143
+ if not self.dsn:
144
+ logger.warning("No DATABASE_URL set — skipping database connection")
145
+ return
146
  self._pool = await asyncpg.create_pool(
147
  self.dsn,
148
+ min_size=2,
149
+ max_size=10,
150
  command_timeout=60,
151
  )
152
  logger.info("Database pool connected")
 
157
  logger.info("Database pool disconnected")
158
 
159
  async def execute(self, query: str, *args):
160
+ if not self._pool:
161
+ return
162
  async with self._pool.acquire() as conn:
163
  return await conn.execute(query, *args)
164
 
165
  async def fetch(self, query: str, *args):
166
+ if not self._pool:
167
+ return []
168
  async with self._pool.acquire() as conn:
169
  return await conn.fetch(query, *args)
170
 
 
 
 
 
171
  async def fetchval(self, query: str, *args):
172
+ if not self._pool:
173
+ return None
174
  async with self._pool.acquire() as conn:
175
  return await conn.fetchval(query, *args)
176
 
177
  # ---------------------------------------------------------------------------
178
+ # Redis Layer
179
  # ---------------------------------------------------------------------------
180
 
181
  class RedisManager:
 
 
182
    def __init__(self, url: str):
        # Client is created in connect(); None until then.
        self.url = url
        self._redis: Optional[aioredis.Redis] = None
 
193
  await self._redis.close()
194
  logger.info("Redis disconnected")
195
 
 
 
196
  async def get_cache(self, key: str) -> Optional[str]:
197
  val = await self._redis.get(key)
198
  if val:
 
204
    async def set_cache(self, key: str, value: str, ttl: int = CACHE_TTL_SECONDS):
        """Store a cache entry with an expiry (Redis SETEX)."""
        await self._redis.setex(key, ttl, value)
206
 
 
 
 
 
 
207
  async def check_rate_limit(self, key: str, rpm: int) -> tuple[bool, float]:
 
208
  now = time.time()
209
  bucket_key = f"ratelimit:{key}"
 
 
210
  script = """
211
  local key = KEYS[1]
212
  local now = tonumber(ARGV[1])
213
  local rpm = tonumber(ARGV[2])
214
  local interval = 60.0 / rpm
 
215
  local last = redis.call('hget', key, 'last')
216
  local tokens = redis.call('hget', key, 'tokens')
 
217
  if not last then
218
  last = 0
219
  tokens = 1
 
221
  last = tonumber(last)
222
  tokens = tonumber(tokens)
223
  end
 
224
  local elapsed = now - last
225
  tokens = math.min(1, tokens + elapsed / interval)
 
226
  if tokens >= 1 then
227
  tokens = tokens - 1
228
  redis.call('hmset', key, 'last', now, 'tokens', tokens)
 
235
  return {0, retry_after}
236
  end
237
  """
 
238
  result = await self._redis.eval(script, 1, bucket_key, now, rpm)
239
  allowed = bool(result[0])
240
  retry_after = float(result[1]) if not allowed else 0.0
241
  return allowed, retry_after
242
 
 
 
243
  async def get_circuit_state(self, provider: str) -> dict:
244
  key = f"circuit:{provider}"
245
  val = await self._redis.get(key)
 
256
  # ---------------------------------------------------------------------------
257
 
258
  class CircuitBreaker:
 
 
259
  def __init__(self, redis: RedisManager, provider: str):
260
  self.redis = redis
261
  self.provider = provider
 
264
 
265
    async def can_execute(self) -> bool:
        """Return True when a call to this provider may proceed.

        Implements the open -> half-open recovery probe: once an open circuit's
        recovery_timeout has elapsed, one probe is let through in half-open state.
        Metric gauge values: 0 = closed, 1 = half-open, 2 = open.
        """
        state = await self.redis.get_circuit_state(self.provider)
        if state["state"] == "open":
            if time.time() - state["last_failure"] > self.recovery_timeout:
                # Recovery window elapsed: allow a probe request through.
                state["state"] = "half-open"
                state["failures"] = 0
                await self.redis.set_circuit_state(self.provider, state)
                CIRCUIT_BREAKER_STATE.labels(provider=self.provider).set(1)
                logger.info(f"Circuit breaker {self.provider} half-open")
                return True
            # Still inside the cooldown: reject.
            CIRCUIT_BREAKER_STATE.labels(provider=self.provider).set(2)
            return False
        # Closed or already half-open: calls are allowed.
        CIRCUIT_BREAKER_STATE.labels(provider=self.provider).set(0 if state["state"] == "closed" else 1)
        return True
279
 
280
  async def record_success(self):
 
284
  state["failures"] = 0
285
  await self.redis.set_circuit_state(self.provider, state)
286
  CIRCUIT_BREAKER_STATE.labels(provider=self.provider).set(0)
287
+ logger.info(f"Circuit breaker {self.provider} closed")
288
 
289
    async def record_failure(self):
        """Count a provider failure and open the circuit at the failure threshold."""
        state = await self.redis.get_circuit_state(self.provider)
        state["failures"] += 1
        state["last_failure"] = time.time()
        if state["failures"] >= self.failure_threshold:
            state["state"] = "open"
            CIRCUIT_BREAKER_STATE.labels(provider=self.provider).set(2)
            logger.warning(f"Circuit breaker {self.provider} OPENED after {state['failures']} failures")
        # Persist outside the branch so the failure count is saved either way.
        await self.redis.set_circuit_state(self.provider, state)
298
+
299
+ # ---------------------------------------------------------------------------
300
+ # Fallback Manager
301
+ # ---------------------------------------------------------------------------
302
+
303
@dataclass
class FallbackConfig:
    """Provider routing order used by FallbackManager."""
    primary: str = "nim"           # tried first
    secondary: str = "cloudflare"  # used when the primary's circuit is open
    enabled: bool = True           # False pins all traffic to the primary
308
+
309
class FallbackManager:
    """Routes LLM traffic to the best available provider: primary -> secondary -> MLX.

    Availability is decided by per-provider circuit breakers whose state lives
    in Redis, so every server instance shares the same view of provider health.
    Also owns the pooled HTTP client used for all outbound provider calls.
    """

    def __init__(self, redis: RedisManager, config: Optional[FallbackConfig] = None):
        # Fix: the parameter was annotated `FallbackConfig = None`, which is not a
        # valid value for that type — it is Optional and defaulted at runtime.
        self.redis = redis
        self.config = config or FallbackConfig()
        # Shared pooled client; created lazily in init_client().
        self._http_client: Optional[httpx.AsyncClient] = None

    async def init_client(self):
        """Create the pooled HTTP client once (idempotent)."""
        if not self._http_client:
            self._http_client = httpx.AsyncClient(
                limits=httpx.Limits(max_connections=50, max_keepalive_connections=20),
                timeout=httpx.Timeout(REQUEST_TIMEOUT),
            )

    async def close_client(self):
        """Close the pooled HTTP client on shutdown."""
        if self._http_client:
            await self._http_client.aclose()

    async def get_active_provider(self) -> tuple[str, dict]:
        """Return ``(provider_name, provider_config)`` for the best available provider.

        Order: primary, then secondary (recording a fallback metric), then the
        local MLX runtime when enabled.

        Raises:
            HTTPException: 503 when every provider's circuit is open.
        """
        if not self.config.enabled:
            return self.config.primary, self._get_provider_config(self.config.primary)

        primary_breaker = CircuitBreaker(self.redis, self.config.primary)
        if await primary_breaker.can_execute():
            return self.config.primary, self._get_provider_config(self.config.primary)

        secondary_breaker = CircuitBreaker(self.redis, self.config.secondary)
        if await secondary_breaker.can_execute():
            FALLBACK_COUNT.labels(
                from_provider=self.config.primary,
                to_provider=self.config.secondary,
                reason="circuit_open",
            ).inc()
            logger.warning(f"Fallback: {self.config.primary} unavailable, switching to {self.config.secondary}")
            return self.config.secondary, self._get_provider_config(self.config.secondary)

        if MLX_ENABLED:
            mlx_breaker = CircuitBreaker(self.redis, "mlx")
            if await mlx_breaker.can_execute():
                FALLBACK_COUNT.labels(
                    from_provider=self.config.primary,
                    to_provider="mlx",
                    reason="both_down",
                ).inc()
                logger.warning("Both cloud providers down — falling back to MLX local")
                return "mlx", self._get_provider_config("mlx")

        raise HTTPException(status_code=503, detail="All LLM providers unavailable.")

    def _get_provider_config(self, provider: str) -> dict:
        """Static per-provider settings; unknown provider names fall back to "nim"."""
        configs = {
            "nim": {
                "api_base": NIM_API_BASE,
                "api_key": os.environ.get("NVIDIA_API_KEY", "no-key"),
                "rpm_limit": 40,
                "cost_per_1m_input": 0.0,
                "cost_per_1m_output": 0.0,
            },
            "cloudflare": {
                "api_base": f"https://api.cloudflare.com/client/v4/accounts/{CLOUDFLARE_ACCOUNT_ID}/ai/v1",
                "api_key": CLOUDFLARE_API_KEY,
                "rpm_limit": 100,
                "cost_per_1m_input": 0.0,
                "cost_per_1m_output": 0.0,
            },
            "mlx": {
                "api_base": MLX_API_BASE,
                "api_key": "no-key",
                "rpm_limit": 1000,
                "cost_per_1m_input": 0.0,
                "cost_per_1m_output": 0.0,
            },
        }
        return configs.get(provider, configs["nim"])
382
 
383
  # ---------------------------------------------------------------------------
384
  # Cost Tracking
 
386
 
387
  @dataclass
388
  class CostTracker:
 
 
389
  session_id: str
390
  budget_usd: float = BUDGET_USD_PER_SESSION
391
  spent_usd: float = 0.0
 
398
    def record_spend(self, cost_usd: float):
        """Accumulate spend for this session and emit the cost metric and log line."""
        self.spent_usd += cost_usd
        LLM_COST_USD.labels(provider=self.provider, model=self.model).inc(cost_usd)
        logger.info(f"Session {self.session_id}: spent ${cost_usd:.4f}, total ${self.spent_usd:.4f} / ${self.budget_usd:.2f}")
 
 
 
402
 
403
  # ---------------------------------------------------------------------------
404
+ # Concurrency Limiter
405
  # ---------------------------------------------------------------------------
406
 
407
  class ConcurrencyLimiter:
 
 
408
    def __init__(self, max_concurrent: int):
        # Semaphore caps the number of in-flight requests for this process.
        self.semaphore = asyncio.Semaphore(max_concurrent)
410
 
 
419
  # ---------------------------------------------------------------------------
420
 
421
  class ChatRequest(BaseModel):
422
+ model: str = Field(..., description="Model ID (e.g., @cf/meta/llama-3.1-8b-instruct)")
423
  messages: list[dict] = Field(..., description="OpenAI-compatible messages")
424
  temperature: Optional[float] = 0.7
425
  max_tokens: Optional[int] = 4096
 
427
  tools: Optional[list[dict]] = None
428
  tool_choice: Optional[str] = "auto"
429
  session_id: Optional[str] = None
430
+ provider_override: Optional[str] = None
431
 
432
class ChatResponse(BaseModel):
    """Response envelope returned by /v1/chat/completions."""
    id: str                  # provider-assigned completion id (or generated UUID)
    session_id: str
    model: str
    provider: str            # backend that actually served the call
    content: Optional[str] = None
    tool_calls: Optional[list[dict]] = None
    usage: dict = Field(default_factory=dict)  # token counts from the provider
    cost_usd: float = 0.0
    cached: bool = False     # True when served from the Redis response cache
    finish_reason: Optional[str] = None
    fallback_used: bool = False  # True when a non-primary provider served it
444
 
445
  class HealthResponse(BaseModel):
446
  status: str
 
450
  redis_connected: bool
451
  db_connected: bool
452
  circuit_breakers: dict[str, str]
453
+ fallback_status: dict[str, str]
 
 
454
 
455
  # ---------------------------------------------------------------------------
456
+ # Global State
457
  # ---------------------------------------------------------------------------
458
 
459
  db_pool: Optional[DatabasePool] = None
460
  redis_manager: Optional[RedisManager] = None
461
  concurrency_limiter: Optional[ConcurrencyLimiter] = None
462
+ fallback_manager: Optional[FallbackManager] = None
463
  start_time: float = 0.0
464
  shutdown_event: asyncio.Event = asyncio.Event()
465
 
 
469
 
470
  @asynccontextmanager
471
  async def lifespan(app: FastAPI):
472
+ global db_pool, redis_manager, concurrency_limiter, fallback_manager, start_time
 
473
 
474
  start_time = time.time()
475
 
 
476
  db_pool = DatabasePool(DATABASE_URL)
477
  await db_pool.connect()
478
 
 
480
  await redis_manager.connect()
481
 
482
  concurrency_limiter = ConcurrencyLimiter(MAX_CONCURRENT_REQUESTS)
483
+ fallback_manager = FallbackManager(redis_manager)
484
+ await fallback_manager.init_client()
485
 
 
486
  loop = asyncio.get_event_loop()
487
  for sig in (signal.SIGTERM, signal.SIGINT):
488
  loop.add_signal_handler(sig, lambda: asyncio.create_task(_shutdown()))
489
 
490
+ if DATABASE_URL:
491
+ await _init_schema()
492
 
493
  logger.info("ml-intern production server started")
494
 
495
  yield
496
 
 
497
  logger.info("Shutting down...")
498
  shutdown_event.set()
499
 
500
+ if fallback_manager:
501
+ await fallback_manager.close_client()
502
  if redis_manager:
503
  await redis_manager.disconnect()
504
  if db_pool:
 
511
  shutdown_event.set()
512
 
513
  async def _init_schema():
 
514
  await db_pool.execute("""
515
  CREATE TABLE IF NOT EXISTS sessions (
516
  id TEXT PRIMARY KEY,
 
524
  await db_pool.execute("""
525
  CREATE TABLE IF NOT EXISTS requests (
526
  id TEXT PRIMARY KEY,
527
+ session_id TEXT,
528
  model TEXT,
529
  provider TEXT,
530
  input_tokens INTEGER,
 
532
  cost_usd NUMERIC,
533
  latency_ms INTEGER,
534
  cached BOOLEAN DEFAULT FALSE,
535
+ fallback_used BOOLEAN DEFAULT FALSE,
 
 
 
 
 
 
 
 
536
  created_at TIMESTAMP DEFAULT NOW()
537
  )
538
  """)
 
540
 
541
# FastAPI application; `lifespan` owns startup/shutdown of pools and clients.
app = FastAPI(
    title="ml-intern Production API",
    description="Production-grade API with NIM/Cloudflare fallback and MLX local support",
    version="1.0.0",
    lifespan=lifespan,
)
 
560
 
561
  @app.middleware("http")
562
  async def correlation_id_middleware(request: Request, call_next):
 
563
  correlation_id = request.headers.get("X-Correlation-ID", str(uuid.uuid4()))
564
  request.state.correlation_id = correlation_id
565
 
 
566
  old_factory = logging.getLogRecordFactory()
567
  def record_factory(*args, **kwargs):
568
  record = old_factory(*args, **kwargs)
 
594
  # Helper Functions
595
  # ---------------------------------------------------------------------------
596
 
597
def estimate_cost(provider_config: dict, input_tokens: int, output_tokens: int) -> float:
    """Dollar cost of one call, given per-million-token rates in *provider_config*.

    Missing rate keys are treated as free (0.0).
    """
    rate_in = provider_config.get("cost_per_1m_input", 0.0)
    rate_out = provider_config.get("cost_per_1m_output", 0.0)
    return (input_tokens / 1_000_000) * rate_in + (output_tokens / 1_000_000) * rate_out
601
 
602
  def generate_cache_key(request: ChatRequest) -> str:
 
 
603
  content = json.dumps({
604
  "model": request.model,
605
  "messages": request.messages,
 
609
  }, sort_keys=True)
610
  return f"cache:llm:{hashlib.sha256(content.encode()).hexdigest()}"
611
 
612
+ # ---------------------------------------------------------------------------
613
+ # LLM Call Implementation
614
+ # ---------------------------------------------------------------------------
615
+
616
async def call_llm(
    provider: str,
    provider_config: dict,
    request: ChatRequest,
    session_id: str,
) -> ChatResponse:
    """Execute one non-streaming chat completion against *provider*.

    Uses the pooled HTTP client owned by the FallbackManager.

    Raises:
        HTTPException: 503 when the client is missing or the provider is
            unreachable; 502 when the provider returns an HTTP error status.
    """
    if not fallback_manager or not fallback_manager._http_client:
        raise HTTPException(status_code=503, detail="HTTP client not initialized")

    api_base = provider_config["api_base"]
    api_key = provider_config["api_key"]

    payload = {
        "model": request.model,
        "messages": request.messages,
        "temperature": request.temperature,
        "max_tokens": request.max_tokens,
        "stream": False,
    }
    # Fix: forward tool definitions when supplied — the request model accepts
    # `tools`/`tool_choice` and the response parser reads `tool_calls`, but the
    # tools were never sent to the provider.
    if request.tools:
        payload["tools"] = request.tools
        payload["tool_choice"] = request.tool_choice

    # All three providers use Bearer auth; the old cloudflare-specific branch
    # re-assigned the identical header and has been removed.
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    t0 = time.time()  # renamed from start_time to avoid shadowing the module global

    try:
        response = await fallback_manager._http_client.post(
            f"{api_base}/chat/completions",
            json=payload,
            headers=headers,
            timeout=REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        data = response.json()

        latency_ms = int((time.time() - t0) * 1000)

        usage = data.get("usage", {})
        input_tokens = usage.get("prompt_tokens", 0)
        output_tokens = usage.get("completion_tokens", 0)
        cost = estimate_cost(provider_config, input_tokens, output_tokens)

        content = None
        tool_calls = None
        finish_reason = None
        choices = data.get("choices") or []
        if choices:
            # Fix: reuse the extracted choice. The previous
            # `data.get("choices", [{}])[0]` raised IndexError whenever the
            # provider returned an empty `choices` list.
            choice = choices[0]
            message = choice.get("message", {})
            content = message.get("content")
            tool_calls = message.get("tool_calls")
            finish_reason = choice.get("finish_reason")

        # latency_ms was previously computed but never used.
        logger.info(f"{provider} completed in {latency_ms}ms ({input_tokens}+{output_tokens} tokens)")

        return ChatResponse(
            id=data.get("id", str(uuid.uuid4())),
            session_id=session_id,
            model=request.model,
            provider=provider,
            content=content,
            tool_calls=tool_calls,
            usage=usage,
            cost_usd=cost,
            cached=False,
            finish_reason=finish_reason,
            fallback_used=False,
        )

    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error from {provider}: {e.response.status_code} - {e.response.text[:200]}")
        raise HTTPException(status_code=502, detail=f"Provider {provider} returned HTTP {e.response.status_code}")
    except httpx.RequestError as e:
        logger.error(f"Network error calling {provider}: {e}")
        raise HTTPException(status_code=503, detail=f"Cannot reach provider {provider}: {str(e)}")
692
+
693
  # ---------------------------------------------------------------------------
694
  # API Endpoints
695
  # ---------------------------------------------------------------------------
696
 
697
  @app.get("/health", response_model=HealthResponse)
698
  async def health_check():
 
699
  uptime = time.time() - start_time
700
 
701
  redis_ok = False
 
706
  pass
707
 
708
  db_ok = False
709
+ if DATABASE_URL:
710
+ try:
711
+ await db_pool.fetchval("SELECT 1")
712
+ db_ok = True
713
+ except Exception:
714
+ pass
715
 
 
716
  circuits = {}
717
+ fallback_status = {}
718
+ for provider in ["nim", "cloudflare", "mlx"]:
719
  try:
720
  state = await redis_manager.get_circuit_state(provider)
721
  circuits[provider] = state["state"]
722
+ fallback_status[provider] = "up" if state["state"] == "closed" else "down"
723
  except Exception:
724
  circuits[provider] = "unknown"
725
+ fallback_status[provider] = "unknown"
726
 
727
  return HealthResponse(
728
+ status="healthy" if redis_ok else "degraded",
729
+ version="1.0.0",
730
  uptime_seconds=uptime,
731
+ active_sessions=0,
732
  redis_connected=redis_ok,
733
  db_connected=db_ok,
734
  circuit_breakers=circuits,
735
+ fallback_status=fallback_status,
736
  )
737
 
738
@app.get("/metrics")
async def metrics():
    """Prometheus scrape endpoint (text exposition format)."""
    from starlette.responses import Response
    payload = generate_latest()
    return Response(content=payload, media_type=CONTENT_TYPE_LATEST)
742
 
743
  @app.post("/v1/chat/completions", response_model=ChatResponse)
744
  async def chat_completions(request: ChatRequest, background_tasks: BackgroundTasks):
 
 
745
  correlation_id = getattr(request.state, "correlation_id", str(uuid.uuid4()))
746
  session_id = request.session_id or str(uuid.uuid4())
 
 
747
 
748
+ logger.info(f"Chat request: model={request.model}, stream={request.stream}, session={session_id}")
 
 
 
749
 
 
750
  await concurrency_limiter.acquire()
751
  try:
752
+ # 1. Determine provider
753
+ if request.provider_override:
754
+ provider = request.provider_override
755
+ provider_config = fallback_manager._get_provider_config(provider)
756
+ breaker = CircuitBreaker(redis_manager, provider)
757
+ if not await breaker.can_execute():
758
+ raise HTTPException(status_code=503, detail=f"Provider {provider} circuit breaker is open")
759
+ else:
760
+ provider, provider_config = await fallback_manager.get_active_provider()
761
 
762
+ request.state.provider = provider
 
 
 
 
 
763
 
764
+ # 2. Rate limiting
765
+ rpm = provider_config.get("rpm_limit", DEFAULT_RPM_LIMIT)
766
  rate_limit_key = f"{provider}:{session_id}"
767
  allowed, retry_after = await redis_manager.check_rate_limit(rate_limit_key, rpm)
768
  if not allowed:
769
  logger.warning(f"Rate limit exceeded for {rate_limit_key}")
770
+ raise HTTPException(status_code=429, detail=f"Rate limit exceeded. Retry after {retry_after:.1f}s", headers={"Retry-After": str(int(retry_after))})
 
 
 
 
771
 
772
+ # 3. Check cache
773
  if not request.stream:
774
  cache_key = generate_cache_key(request)
775
  cached = await redis_manager.get_cache(cache_key)
 
780
  id=str(uuid.uuid4()),
781
  session_id=session_id,
782
  model=request.model,
783
+ provider=provider,
784
  content=data.get("content"),
785
  tool_calls=data.get("tool_calls"),
786
  usage=data.get("usage", {}),
787
  cost_usd=0.0,
788
  cached=True,
789
  finish_reason=data.get("finish_reason"),
790
+ fallback_used=False,
791
  )
792
 
793
+ # 4. Budget
794
+ cost_tracker = CostTracker(session_id, provider=provider, model=request.model)
 
 
 
 
 
 
 
 
 
795
 
796
+ # 5. Call LLM
797
+ response = await call_llm(provider, provider_config, request, session_id)
798
 
799
+ # Record success
800
+ breaker = CircuitBreaker(redis_manager, provider)
801
  await breaker.record_success()
802
 
803
+ cost_tracker.record_spend(response.cost_usd)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
804
 
805
+ # 6. Cache response
806
  if not request.stream:
807
  cache_key = generate_cache_key(request)
808
  await redis_manager.set_cache(
 
815
  }),
816
  )
817
 
818
+ # 7. Persist
819
+ if DATABASE_URL:
820
+ background_tasks.add_task(_persist_request, session_id, request, response, provider)
821
+
822
+ # Mark fallback
823
+ if provider != FALLBACK_PRIMARY and FALLBACK_ENABLED:
824
+ response.fallback_used = True
825
 
826
  return response
827
 
 
829
  raise
830
  except Exception as e:
831
  logger.exception(f"Error processing request: {e}")
832
+ breaker = CircuitBreaker(redis_manager, provider if 'provider' in locals() else "unknown")
 
833
  await breaker.record_failure()
834
  raise HTTPException(status_code=500, detail=str(e))
835
  finally:
836
  concurrency_limiter.release()
837
 
838
async def _persist_request(session_id: str, request: ChatRequest, response: ChatResponse, provider: str):
    """Background task: insert one request-audit row; failures are logged, never raised."""
    try:
        await db_pool.execute(
            """
            INSERT INTO requests (id, session_id, model, provider, input_tokens,
                                  output_tokens, cost_usd, latency_ms, cached, fallback_used)
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
            """,
            response.id,
            session_id,
            request.model,
            provider,
            response.usage.get("prompt_tokens", 0),
            response.usage.get("completion_tokens", 0),
            response.cost_usd,
            0,  # latency_ms: not propagated to this task — TODO wire the measured latency through
            response.cached,
            response.fallback_used,
        )
    except Exception as e:
        # Best-effort persistence: audit loss must never fail the user request.
        logger.error(f"Failed to persist request: {e}")
859
 
860
@app.get("/v1/models")
async def list_models():
    """List models for every provider whose credentials/runtime are configured."""
    nim_models = [
        {"id": "nim/llama-3.1-405b-instruct", "object": "model", "owned_by": "nvidia"},
        {"id": "nim/llama-3.1-70b-instruct", "object": "model", "owned_by": "nvidia"},
        {"id": "nim/llama-3.1-8b-instruct", "object": "model", "owned_by": "nvidia"},
        {"id": "nim/mistral-7b-instruct", "object": "model", "owned_by": "nvidia"},
    ]
    cloudflare_models = [
        {"id": "cloudflare/@cf/meta/llama-3.1-8b-instruct", "object": "model", "owned_by": "cloudflare"},
        {"id": "cloudflare/@cf/meta/llama-3.1-70b-instruct", "object": "model", "owned_by": "cloudflare"},
        {"id": "cloudflare/@cf/mistral/mistral-7b-instruct", "object": "model", "owned_by": "cloudflare"},
        {"id": "cloudflare/@cf/qwen/qwen1.5-14b-chat-awq", "object": "model", "owned_by": "cloudflare"},
    ]
    mlx_models = [
        {"id": "mlx/llama-3.1-8b", "object": "model", "owned_by": "mlx"},
        {"id": "mlx/llama-3.1-70b", "object": "model", "owned_by": "mlx"},
    ]

    catalog: list[dict] = []
    if os.environ.get("NVIDIA_API_KEY"):
        catalog.extend(nim_models)
    if CLOUDFLARE_API_KEY and CLOUDFLARE_ACCOUNT_ID:
        catalog.extend(cloudflare_models)
    if MLX_ENABLED:
        catalog.extend(mlx_models)

    return {"object": "list", "data": catalog}
887
+
888
@app.get("/v1/fallback/status")
async def fallback_status():
    """Report circuit-breaker state and availability for every provider.

    NOTE(review): can_execute() is not read-only — on an expired open circuit
    it transitions the breaker to half-open. Polling this status endpoint can
    therefore trigger recovery probes; confirm that is intended.
    """
    status = {}
    for provider in ["nim", "cloudflare", "mlx"]:
        breaker = CircuitBreaker(redis_manager, provider)
        can_execute = await breaker.can_execute()
        state = await redis_manager.get_circuit_state(provider)
        status[provider] = {
            "circuit_state": state["state"],
            "failures": state["failures"],
            "available": can_execute,
            "last_failure": state["last_failure"],
        }

    return {
        "fallback_enabled": FALLBACK_ENABLED,
        "primary": FALLBACK_PRIMARY,
        "secondary": FALLBACK_SECONDARY,
        "providers": status,
        "active_provider": await _get_active_provider_name(),
    }
909
 
910
async def _get_active_provider_name() -> str:
    """Name of the provider the fallback chain would pick right now."""
    try:
        name, _config = await fallback_manager.get_active_provider()
    except HTTPException:
        return "none_available"
    return name
 
 
 
 
 
 
 
916
 
917
  # ---------------------------------------------------------------------------
918
  # Main Entry Point