raazkumar committed on
Commit
42855cf
·
verified ·
1 Parent(s): e979df8

Upload production/production_server.py

Browse files
Files changed (1) hide show
  1. production/production_server.py +52 -18
production/production_server.py CHANGED
@@ -5,7 +5,7 @@ Features:
5
  - FastAPI with async endpoints
6
  - Distributed rate limiting (Redis-backed token bucket)
7
  - Circuit breaker for external API resilience
8
- - Automatic fallback: NIM (primary) -> Cloudflare Workers AI (fallback)
9
  - Request/response caching with Redis TTL
10
  - Multi-tenant session isolation
11
  - Health checks and graceful shutdown
@@ -13,6 +13,8 @@ Features:
13
  - Cost tracking and budget enforcement
14
  - Connection pooling for all HTTP clients
15
  - Cloudflare Workers AI support via OpenAI-compatible API
 
 
16
  """
17
 
18
  import asyncio
@@ -58,10 +60,15 @@ NIM_API_BASE = os.environ.get("NIM_API_BASE", "https://integrate.api.nvidia.com/
58
  CLOUDFLARE_API_KEY = os.environ.get("CLOUDFLARE_API_KEY", "")
59
  CLOUDFLARE_ACCOUNT_ID = os.environ.get("CLOUDFLARE_ACCOUNT_ID", "")
60
 
 
 
 
 
61
  # Fallback configuration
62
  FALLBACK_ENABLED = os.environ.get("FALLBACK_ENABLED", "true").lower() == "true"
63
  FALLBACK_PRIMARY = os.environ.get("FALLBACK_PRIMARY", "nim")
64
  FALLBACK_SECONDARY = os.environ.get("FALLBACK_SECONDARY", "cloudflare")
 
65
 
66
  # MLX (local Apple Silicon)
67
  MLX_API_BASE = os.environ.get("MLX_API_BASE", "http://localhost:8000/v1")
@@ -304,6 +311,7 @@ class CircuitBreaker:
304
  class FallbackConfig:
305
  primary: str = "nim"
306
  secondary: str = "cloudflare"
 
307
  enabled: bool = True
308
 
309
  class FallbackManager:
@@ -338,18 +346,33 @@ class FallbackManager:
338
  to_provider=self.config.secondary,
339
  reason="circuit_open",
340
  ).inc()
341
- logger.warning(f"Fallback: {self.config.primary} unavailable, switching to {self.config.secondary}")
 
 
342
  return self.config.secondary, self._get_provider_config(self.config.secondary)
343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
  if MLX_ENABLED:
345
  mlx_breaker = CircuitBreaker(self.redis, "mlx")
346
  if await mlx_breaker.can_execute():
347
  FALLBACK_COUNT.labels(
348
- from_provider=self.config.primary,
349
  to_provider="mlx",
350
- reason="both_down",
351
  ).inc()
352
- logger.warning("Both cloud providers down — falling back to MLX local")
353
  return "mlx", self._get_provider_config("mlx")
354
 
355
  raise HTTPException(status_code=503, detail="All LLM providers unavailable.")
@@ -370,6 +393,13 @@ class FallbackManager:
370
  "cost_per_1m_input": 0.0,
371
  "cost_per_1m_output": 0.0,
372
  },
 
 
 
 
 
 
 
373
  "mlx": {
374
  "api_base": MLX_API_BASE,
375
  "api_key": "no-key",
@@ -419,7 +449,7 @@ class ConcurrencyLimiter:
419
  # ---------------------------------------------------------------------------
420
 
421
  class ChatRequest(BaseModel):
422
- model: str = Field(..., description="Model ID (e.g., @cf/meta/llama-3.1-8b-instruct)")
423
  messages: list[dict] = Field(..., description="OpenAI-compatible messages")
424
  temperature: Optional[float] = 0.7
425
  max_tokens: Optional[int] = 4096
@@ -540,7 +570,7 @@ async def _init_schema():
540
 
541
  app = FastAPI(
542
  title="ml-intern Production API",
543
- description="Production-grade API with NIM/Cloudflare fallback and MLX local support",
544
  version="1.0.0",
545
  lifespan=lifespan,
546
  )
@@ -715,7 +745,7 @@ async def health_check():
715
 
716
  circuits = {}
717
  fallback_status = {}
718
- for provider in ["nim", "cloudflare", "mlx"]:
719
  try:
720
  state = await redis_manager.get_circuit_state(provider)
721
  circuits[provider] = state["state"]
@@ -749,7 +779,6 @@ async def chat_completions(request: ChatRequest, background_tasks: BackgroundTas
749
 
750
  await concurrency_limiter.acquire()
751
  try:
752
- # 1. Determine provider
753
  if request.provider_override:
754
  provider = request.provider_override
755
  provider_config = fallback_manager._get_provider_config(provider)
@@ -761,7 +790,6 @@ async def chat_completions(request: ChatRequest, background_tasks: BackgroundTas
761
 
762
  request.state.provider = provider
763
 
764
- # 2. Rate limiting
765
  rpm = provider_config.get("rpm_limit", DEFAULT_RPM_LIMIT)
766
  rate_limit_key = f"{provider}:{session_id}"
767
  allowed, retry_after = await redis_manager.check_rate_limit(rate_limit_key, rpm)
@@ -769,7 +797,6 @@ async def chat_completions(request: ChatRequest, background_tasks: BackgroundTas
769
  logger.warning(f"Rate limit exceeded for {rate_limit_key}")
770
  raise HTTPException(status_code=429, detail=f"Rate limit exceeded. Retry after {retry_after:.1f}s", headers={"Retry-After": str(int(retry_after))})
771
 
772
- # 3. Check cache
773
  if not request.stream:
774
  cache_key = generate_cache_key(request)
775
  cached = await redis_manager.get_cache(cache_key)
@@ -790,19 +817,15 @@ async def chat_completions(request: ChatRequest, background_tasks: BackgroundTas
790
  fallback_used=False,
791
  )
792
 
793
- # 4. Budget
794
  cost_tracker = CostTracker(session_id, provider=provider, model=request.model)
795
 
796
- # 5. Call LLM
797
  response = await call_llm(provider, provider_config, request, session_id)
798
 
799
- # Record success
800
  breaker = CircuitBreaker(redis_manager, provider)
801
  await breaker.record_success()
802
 
803
  cost_tracker.record_spend(response.cost_usd)
804
 
805
- # 6. Cache response
806
  if not request.stream:
807
  cache_key = generate_cache_key(request)
808
  await redis_manager.set_cache(
@@ -815,11 +838,9 @@ async def chat_completions(request: ChatRequest, background_tasks: BackgroundTas
815
  }),
816
  )
817
 
818
- # 7. Persist
819
  if DATABASE_URL:
820
  background_tasks.add_task(_persist_request, session_id, request, response, provider)
821
 
822
- # Mark fallback
823
  if provider != FALLBACK_PRIMARY and FALLBACK_ENABLED:
824
  response.fallback_used = True
825
 
@@ -875,12 +896,24 @@ async def list_models():
875
  {"id": "cloudflare/@cf/meta/llama-3.1-70b-instruct", "object": "model", "owned_by": "cloudflare"},
876
  {"id": "cloudflare/@cf/mistral/mistral-7b-instruct", "object": "model", "owned_by": "cloudflare"},
877
  {"id": "cloudflare/@cf/qwen/qwen1.5-14b-chat-awq", "object": "model", "owned_by": "cloudflare"},
 
 
 
 
 
 
 
 
 
878
  ])
879
 
880
  if MLX_ENABLED:
881
  models.extend([
882
  {"id": "mlx/llama-3.1-8b", "object": "model", "owned_by": "mlx"},
883
  {"id": "mlx/llama-3.1-70b", "object": "model", "owned_by": "mlx"},
 
 
 
884
  ])
885
 
886
  return {"object": "list", "data": models}
@@ -888,7 +921,7 @@ async def list_models():
888
  @app.get("/v1/fallback/status")
889
  async def fallback_status():
890
  status = {}
891
- for provider in ["nim", "cloudflare", "mlx"]:
892
  breaker = CircuitBreaker(redis_manager, provider)
893
  can_execute = await breaker.can_execute()
894
  state = await redis_manager.get_circuit_state(provider)
@@ -903,6 +936,7 @@ async def fallback_status():
903
  "fallback_enabled": FALLBACK_ENABLED,
904
  "primary": FALLBACK_PRIMARY,
905
  "secondary": FALLBACK_SECONDARY,
 
906
  "providers": status,
907
  "active_provider": await _get_active_provider_name(),
908
  }
 
5
  - FastAPI with async endpoints
6
  - Distributed rate limiting (Redis-backed token bucket)
7
  - Circuit breaker for external API resilience
8
+ - Automatic fallback: NIM (primary) -> Cloudflare Workers AI (secondary) -> Gemini (tertiary) -> MLX
9
  - Request/response caching with Redis TTL
10
  - Multi-tenant session isolation
11
  - Health checks and graceful shutdown
 
13
  - Cost tracking and budget enforcement
14
  - Connection pooling for all HTTP clients
15
  - Cloudflare Workers AI support via OpenAI-compatible API
16
+ - Google Gemini support via OpenAI-compatible API
17
+ - MLX local support for Apple Silicon
18
  """
19
 
20
  import asyncio
 
60
  CLOUDFLARE_API_KEY = os.environ.get("CLOUDFLARE_API_KEY", "")
61
  CLOUDFLARE_ACCOUNT_ID = os.environ.get("CLOUDFLARE_ACCOUNT_ID", "")
62
 
63
+ # Google Gemini / AI Studio
64
+ GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
65
+ GEMINI_API_BASE = os.environ.get("GEMINI_API_BASE", "https://generativelanguage.googleapis.com/v1beta/openai")
66
+
67
  # Fallback configuration
68
  FALLBACK_ENABLED = os.environ.get("FALLBACK_ENABLED", "true").lower() == "true"
69
  FALLBACK_PRIMARY = os.environ.get("FALLBACK_PRIMARY", "nim")
70
  FALLBACK_SECONDARY = os.environ.get("FALLBACK_SECONDARY", "cloudflare")
71
+ FALLBACK_TERTIARY = os.environ.get("FALLBACK_TERTIARY", "gemini")
72
 
73
  # MLX (local Apple Silicon)
74
  MLX_API_BASE = os.environ.get("MLX_API_BASE", "http://localhost:8000/v1")
 
311
  class FallbackConfig:
312
  primary: str = "nim"
313
  secondary: str = "cloudflare"
314
+ tertiary: str = "gemini"
315
  enabled: bool = True
316
 
317
  class FallbackManager:
 
346
  to_provider=self.config.secondary,
347
  reason="circuit_open",
348
  ).inc()
349
+ logger.warning(
350
+ f"Fallback: {self.config.primary} unavailable, switching to {self.config.secondary}"
351
+ )
352
  return self.config.secondary, self._get_provider_config(self.config.secondary)
353
 
354
+ tertiary_breaker = CircuitBreaker(self.redis, self.config.tertiary)
355
+ if await tertiary_breaker.can_execute():
356
+ FALLBACK_COUNT.labels(
357
+ from_provider=self.config.secondary,
358
+ to_provider=self.config.tertiary,
359
+ reason="secondary_down",
360
+ ).inc()
361
+ logger.warning(
362
+ f"Fallback: both {self.config.primary} and {self.config.secondary} down, "
363
+ f"switching to {self.config.tertiary}"
364
+ )
365
+ return self.config.tertiary, self._get_provider_config(self.config.tertiary)
366
+
367
  if MLX_ENABLED:
368
  mlx_breaker = CircuitBreaker(self.redis, "mlx")
369
  if await mlx_breaker.can_execute():
370
  FALLBACK_COUNT.labels(
371
+ from_provider=self.config.tertiary,
372
  to_provider="mlx",
373
+ reason="all_cloud_down",
374
  ).inc()
375
+ logger.warning("All cloud providers down — falling back to MLX local")
376
  return "mlx", self._get_provider_config("mlx")
377
 
378
  raise HTTPException(status_code=503, detail="All LLM providers unavailable.")
 
393
  "cost_per_1m_input": 0.0,
394
  "cost_per_1m_output": 0.0,
395
  },
396
+ "gemini": {
397
+ "api_base": GEMINI_API_BASE,
398
+ "api_key": GEMINI_API_KEY,
399
+ "rpm_limit": 60,
400
+ "cost_per_1m_input": 0.075,
401
+ "cost_per_1m_output": 0.30,
402
+ },
403
  "mlx": {
404
  "api_base": MLX_API_BASE,
405
  "api_key": "no-key",
 
449
  # ---------------------------------------------------------------------------
450
 
451
  class ChatRequest(BaseModel):
452
+ model: str = Field(..., description="Model ID (e.g., gemma-4-31b-bf16)")
453
  messages: list[dict] = Field(..., description="OpenAI-compatible messages")
454
  temperature: Optional[float] = 0.7
455
  max_tokens: Optional[int] = 4096
 
570
 
571
  app = FastAPI(
572
  title="ml-intern Production API",
573
+ description="Production-grade API with NIM/Cloudflare/Gemini fallback and MLX local support",
574
  version="1.0.0",
575
  lifespan=lifespan,
576
  )
 
745
 
746
  circuits = {}
747
  fallback_status = {}
748
+ for provider in ["nim", "cloudflare", "gemini", "mlx"]:
749
  try:
750
  state = await redis_manager.get_circuit_state(provider)
751
  circuits[provider] = state["state"]
 
779
 
780
  await concurrency_limiter.acquire()
781
  try:
 
782
  if request.provider_override:
783
  provider = request.provider_override
784
  provider_config = fallback_manager._get_provider_config(provider)
 
790
 
791
  request.state.provider = provider
792
 
 
793
  rpm = provider_config.get("rpm_limit", DEFAULT_RPM_LIMIT)
794
  rate_limit_key = f"{provider}:{session_id}"
795
  allowed, retry_after = await redis_manager.check_rate_limit(rate_limit_key, rpm)
 
797
  logger.warning(f"Rate limit exceeded for {rate_limit_key}")
798
  raise HTTPException(status_code=429, detail=f"Rate limit exceeded. Retry after {retry_after:.1f}s", headers={"Retry-After": str(int(retry_after))})
799
 
 
800
  if not request.stream:
801
  cache_key = generate_cache_key(request)
802
  cached = await redis_manager.get_cache(cache_key)
 
817
  fallback_used=False,
818
  )
819
 
 
820
  cost_tracker = CostTracker(session_id, provider=provider, model=request.model)
821
 
 
822
  response = await call_llm(provider, provider_config, request, session_id)
823
 
 
824
  breaker = CircuitBreaker(redis_manager, provider)
825
  await breaker.record_success()
826
 
827
  cost_tracker.record_spend(response.cost_usd)
828
 
 
829
  if not request.stream:
830
  cache_key = generate_cache_key(request)
831
  await redis_manager.set_cache(
 
838
  }),
839
  )
840
 
 
841
  if DATABASE_URL:
842
  background_tasks.add_task(_persist_request, session_id, request, response, provider)
843
 
 
844
  if provider != FALLBACK_PRIMARY and FALLBACK_ENABLED:
845
  response.fallback_used = True
846
 
 
896
  {"id": "cloudflare/@cf/meta/llama-3.1-70b-instruct", "object": "model", "owned_by": "cloudflare"},
897
  {"id": "cloudflare/@cf/mistral/mistral-7b-instruct", "object": "model", "owned_by": "cloudflare"},
898
  {"id": "cloudflare/@cf/qwen/qwen1.5-14b-chat-awq", "object": "model", "owned_by": "cloudflare"},
899
+ {"id": "cloudflare/@cf/google/gemma-4-26b-a4b-it", "object": "model", "owned_by": "cloudflare"},
900
+ ])
901
+
902
+ if GEMINI_API_KEY:
903
+ models.extend([
904
+ {"id": "gemini/gemini-2.5-pro-preview", "object": "model", "owned_by": "google"},
905
+ {"id": "gemini/gemini-2.5-flash-preview", "object": "model", "owned_by": "google"},
906
+ {"id": "gemini/gemma-4-26b", "object": "model", "owned_by": "google"},
907
+ {"id": "gemini/gemma-4-9b", "object": "model", "owned_by": "google"},
908
  ])
909
 
910
  if MLX_ENABLED:
911
  models.extend([
912
  {"id": "mlx/llama-3.1-8b", "object": "model", "owned_by": "mlx"},
913
  {"id": "mlx/llama-3.1-70b", "object": "model", "owned_by": "mlx"},
914
+ {"id": "mlx/gemma-4-26b-a4b-it", "object": "model", "owned_by": "mlx"},
915
+ {"id": "mlx/gemma-4-31b-bf16", "object": "model", "owned_by": "mlx"},
916
+ {"id": "mlx/gemma-4-e4b-it", "object": "model", "owned_by": "mlx"},
917
  ])
918
 
919
  return {"object": "list", "data": models}
 
921
  @app.get("/v1/fallback/status")
922
  async def fallback_status():
923
  status = {}
924
+ for provider in ["nim", "cloudflare", "gemini", "mlx"]:
925
  breaker = CircuitBreaker(redis_manager, provider)
926
  can_execute = await breaker.can_execute()
927
  state = await redis_manager.get_circuit_state(provider)
 
936
  "fallback_enabled": FALLBACK_ENABLED,
937
  "primary": FALLBACK_PRIMARY,
938
  "secondary": FALLBACK_SECONDARY,
939
+ "tertiary": FALLBACK_TERTIARY,
940
  "providers": status,
941
  "active_provider": await _get_active_provider_name(),
942
  }