Upload production/production_server.py
Browse files
- production/production_server.py +52 -18
production/production_server.py
CHANGED
|
@@ -5,7 +5,7 @@ Features:
|
|
| 5 |
- FastAPI with async endpoints
|
| 6 |
- Distributed rate limiting (Redis-backed token bucket)
|
| 7 |
- Circuit breaker for external API resilience
|
| 8 |
-
- Automatic fallback: NIM (primary) -> Cloudflare Workers AI (
|
| 9 |
- Request/response caching with Redis TTL
|
| 10 |
- Multi-tenant session isolation
|
| 11 |
- Health checks and graceful shutdown
|
|
@@ -13,6 +13,8 @@ Features:
|
|
| 13 |
- Cost tracking and budget enforcement
|
| 14 |
- Connection pooling for all HTTP clients
|
| 15 |
- Cloudflare Workers AI support via OpenAI-compatible API
|
|
|
|
|
|
|
| 16 |
"""
|
| 17 |
|
| 18 |
import asyncio
|
|
@@ -58,10 +60,15 @@ NIM_API_BASE = os.environ.get("NIM_API_BASE", "https://integrate.api.nvidia.com/
|
|
| 58 |
CLOUDFLARE_API_KEY = os.environ.get("CLOUDFLARE_API_KEY", "")
|
| 59 |
CLOUDFLARE_ACCOUNT_ID = os.environ.get("CLOUDFLARE_ACCOUNT_ID", "")
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
# Fallback configuration
|
| 62 |
FALLBACK_ENABLED = os.environ.get("FALLBACK_ENABLED", "true").lower() == "true"
|
| 63 |
FALLBACK_PRIMARY = os.environ.get("FALLBACK_PRIMARY", "nim")
|
| 64 |
FALLBACK_SECONDARY = os.environ.get("FALLBACK_SECONDARY", "cloudflare")
|
|
|
|
| 65 |
|
| 66 |
# MLX (local Apple Silicon)
|
| 67 |
MLX_API_BASE = os.environ.get("MLX_API_BASE", "http://localhost:8000/v1")
|
|
@@ -304,6 +311,7 @@ class CircuitBreaker:
|
|
| 304 |
class FallbackConfig:
|
| 305 |
primary: str = "nim"
|
| 306 |
secondary: str = "cloudflare"
|
|
|
|
| 307 |
enabled: bool = True
|
| 308 |
|
| 309 |
class FallbackManager:
|
|
@@ -338,18 +346,33 @@ class FallbackManager:
|
|
| 338 |
to_provider=self.config.secondary,
|
| 339 |
reason="circuit_open",
|
| 340 |
).inc()
|
| 341 |
-
logger.warning(
|
|
|
|
|
|
|
| 342 |
return self.config.secondary, self._get_provider_config(self.config.secondary)
|
| 343 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
if MLX_ENABLED:
|
| 345 |
mlx_breaker = CircuitBreaker(self.redis, "mlx")
|
| 346 |
if await mlx_breaker.can_execute():
|
| 347 |
FALLBACK_COUNT.labels(
|
| 348 |
-
from_provider=self.config.
|
| 349 |
to_provider="mlx",
|
| 350 |
-
reason="
|
| 351 |
).inc()
|
| 352 |
-
logger.warning("
|
| 353 |
return "mlx", self._get_provider_config("mlx")
|
| 354 |
|
| 355 |
raise HTTPException(status_code=503, detail="All LLM providers unavailable.")
|
|
@@ -370,6 +393,13 @@ class FallbackManager:
|
|
| 370 |
"cost_per_1m_input": 0.0,
|
| 371 |
"cost_per_1m_output": 0.0,
|
| 372 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
"mlx": {
|
| 374 |
"api_base": MLX_API_BASE,
|
| 375 |
"api_key": "no-key",
|
|
@@ -419,7 +449,7 @@ class ConcurrencyLimiter:
|
|
| 419 |
# ---------------------------------------------------------------------------
|
| 420 |
|
| 421 |
class ChatRequest(BaseModel):
|
| 422 |
-
model: str = Field(..., description="Model ID (e.g.,
|
| 423 |
messages: list[dict] = Field(..., description="OpenAI-compatible messages")
|
| 424 |
temperature: Optional[float] = 0.7
|
| 425 |
max_tokens: Optional[int] = 4096
|
|
@@ -540,7 +570,7 @@ async def _init_schema():
|
|
| 540 |
|
| 541 |
app = FastAPI(
|
| 542 |
title="ml-intern Production API",
|
| 543 |
-
description="Production-grade API with NIM/Cloudflare fallback and MLX local support",
|
| 544 |
version="1.0.0",
|
| 545 |
lifespan=lifespan,
|
| 546 |
)
|
|
@@ -715,7 +745,7 @@ async def health_check():
|
|
| 715 |
|
| 716 |
circuits = {}
|
| 717 |
fallback_status = {}
|
| 718 |
-
for provider in ["nim", "cloudflare", "mlx"]:
|
| 719 |
try:
|
| 720 |
state = await redis_manager.get_circuit_state(provider)
|
| 721 |
circuits[provider] = state["state"]
|
|
@@ -749,7 +779,6 @@ async def chat_completions(request: ChatRequest, background_tasks: BackgroundTas
|
|
| 749 |
|
| 750 |
await concurrency_limiter.acquire()
|
| 751 |
try:
|
| 752 |
-
# 1. Determine provider
|
| 753 |
if request.provider_override:
|
| 754 |
provider = request.provider_override
|
| 755 |
provider_config = fallback_manager._get_provider_config(provider)
|
|
@@ -761,7 +790,6 @@ async def chat_completions(request: ChatRequest, background_tasks: BackgroundTas
|
|
| 761 |
|
| 762 |
request.state.provider = provider
|
| 763 |
|
| 764 |
-
# 2. Rate limiting
|
| 765 |
rpm = provider_config.get("rpm_limit", DEFAULT_RPM_LIMIT)
|
| 766 |
rate_limit_key = f"{provider}:{session_id}"
|
| 767 |
allowed, retry_after = await redis_manager.check_rate_limit(rate_limit_key, rpm)
|
|
@@ -769,7 +797,6 @@ async def chat_completions(request: ChatRequest, background_tasks: BackgroundTas
|
|
| 769 |
logger.warning(f"Rate limit exceeded for {rate_limit_key}")
|
| 770 |
raise HTTPException(status_code=429, detail=f"Rate limit exceeded. Retry after {retry_after:.1f}s", headers={"Retry-After": str(int(retry_after))})
|
| 771 |
|
| 772 |
-
# 3. Check cache
|
| 773 |
if not request.stream:
|
| 774 |
cache_key = generate_cache_key(request)
|
| 775 |
cached = await redis_manager.get_cache(cache_key)
|
|
@@ -790,19 +817,15 @@ async def chat_completions(request: ChatRequest, background_tasks: BackgroundTas
|
|
| 790 |
fallback_used=False,
|
| 791 |
)
|
| 792 |
|
| 793 |
-
# 4. Budget
|
| 794 |
cost_tracker = CostTracker(session_id, provider=provider, model=request.model)
|
| 795 |
|
| 796 |
-
# 5. Call LLM
|
| 797 |
response = await call_llm(provider, provider_config, request, session_id)
|
| 798 |
|
| 799 |
-
# Record success
|
| 800 |
breaker = CircuitBreaker(redis_manager, provider)
|
| 801 |
await breaker.record_success()
|
| 802 |
|
| 803 |
cost_tracker.record_spend(response.cost_usd)
|
| 804 |
|
| 805 |
-
# 6. Cache response
|
| 806 |
if not request.stream:
|
| 807 |
cache_key = generate_cache_key(request)
|
| 808 |
await redis_manager.set_cache(
|
|
@@ -815,11 +838,9 @@ async def chat_completions(request: ChatRequest, background_tasks: BackgroundTas
|
|
| 815 |
}),
|
| 816 |
)
|
| 817 |
|
| 818 |
-
# 7. Persist
|
| 819 |
if DATABASE_URL:
|
| 820 |
background_tasks.add_task(_persist_request, session_id, request, response, provider)
|
| 821 |
|
| 822 |
-
# Mark fallback
|
| 823 |
if provider != FALLBACK_PRIMARY and FALLBACK_ENABLED:
|
| 824 |
response.fallback_used = True
|
| 825 |
|
|
@@ -875,12 +896,24 @@ async def list_models():
|
|
| 875 |
{"id": "cloudflare/@cf/meta/llama-3.1-70b-instruct", "object": "model", "owned_by": "cloudflare"},
|
| 876 |
{"id": "cloudflare/@cf/mistral/mistral-7b-instruct", "object": "model", "owned_by": "cloudflare"},
|
| 877 |
{"id": "cloudflare/@cf/qwen/qwen1.5-14b-chat-awq", "object": "model", "owned_by": "cloudflare"},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 878 |
])
|
| 879 |
|
| 880 |
if MLX_ENABLED:
|
| 881 |
models.extend([
|
| 882 |
{"id": "mlx/llama-3.1-8b", "object": "model", "owned_by": "mlx"},
|
| 883 |
{"id": "mlx/llama-3.1-70b", "object": "model", "owned_by": "mlx"},
|
|
|
|
|
|
|
|
|
|
| 884 |
])
|
| 885 |
|
| 886 |
return {"object": "list", "data": models}
|
|
@@ -888,7 +921,7 @@ async def list_models():
|
|
| 888 |
@app.get("/v1/fallback/status")
|
| 889 |
async def fallback_status():
|
| 890 |
status = {}
|
| 891 |
-
for provider in ["nim", "cloudflare", "mlx"]:
|
| 892 |
breaker = CircuitBreaker(redis_manager, provider)
|
| 893 |
can_execute = await breaker.can_execute()
|
| 894 |
state = await redis_manager.get_circuit_state(provider)
|
|
@@ -903,6 +936,7 @@ async def fallback_status():
|
|
| 903 |
"fallback_enabled": FALLBACK_ENABLED,
|
| 904 |
"primary": FALLBACK_PRIMARY,
|
| 905 |
"secondary": FALLBACK_SECONDARY,
|
|
|
|
| 906 |
"providers": status,
|
| 907 |
"active_provider": await _get_active_provider_name(),
|
| 908 |
}
|
|
|
|
| 5 |
- FastAPI with async endpoints
|
| 6 |
- Distributed rate limiting (Redis-backed token bucket)
|
| 7 |
- Circuit breaker for external API resilience
|
| 8 |
+
- Automatic fallback: NIM (primary) -> Cloudflare Workers AI (secondary) -> Gemini (tertiary) -> MLX
|
| 9 |
- Request/response caching with Redis TTL
|
| 10 |
- Multi-tenant session isolation
|
| 11 |
- Health checks and graceful shutdown
|
|
|
|
| 13 |
- Cost tracking and budget enforcement
|
| 14 |
- Connection pooling for all HTTP clients
|
| 15 |
- Cloudflare Workers AI support via OpenAI-compatible API
|
| 16 |
+
- Google Gemini support via OpenAI-compatible API
|
| 17 |
+
- MLX local support for Apple Silicon
|
| 18 |
"""
|
| 19 |
|
| 20 |
import asyncio
|
|
|
|
| 60 |
CLOUDFLARE_API_KEY = os.environ.get("CLOUDFLARE_API_KEY", "")
|
| 61 |
CLOUDFLARE_ACCOUNT_ID = os.environ.get("CLOUDFLARE_ACCOUNT_ID", "")
|
| 62 |
|
| 63 |
+
# Google Gemini / AI Studio
|
| 64 |
+
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
|
| 65 |
+
GEMINI_API_BASE = os.environ.get("GEMINI_API_BASE", "https://generativelanguage.googleapis.com/v1beta/openai")
|
| 66 |
+
|
| 67 |
# Fallback configuration
|
| 68 |
FALLBACK_ENABLED = os.environ.get("FALLBACK_ENABLED", "true").lower() == "true"
|
| 69 |
FALLBACK_PRIMARY = os.environ.get("FALLBACK_PRIMARY", "nim")
|
| 70 |
FALLBACK_SECONDARY = os.environ.get("FALLBACK_SECONDARY", "cloudflare")
|
| 71 |
+
FALLBACK_TERTIARY = os.environ.get("FALLBACK_TERTIARY", "gemini")
|
| 72 |
|
| 73 |
# MLX (local Apple Silicon)
|
| 74 |
MLX_API_BASE = os.environ.get("MLX_API_BASE", "http://localhost:8000/v1")
|
|
|
|
| 311 |
class FallbackConfig:
|
| 312 |
primary: str = "nim"
|
| 313 |
secondary: str = "cloudflare"
|
| 314 |
+
tertiary: str = "gemini"
|
| 315 |
enabled: bool = True
|
| 316 |
|
| 317 |
class FallbackManager:
|
|
|
|
| 346 |
to_provider=self.config.secondary,
|
| 347 |
reason="circuit_open",
|
| 348 |
).inc()
|
| 349 |
+
logger.warning(
|
| 350 |
+
f"Fallback: {self.config.primary} unavailable, switching to {self.config.secondary}"
|
| 351 |
+
)
|
| 352 |
return self.config.secondary, self._get_provider_config(self.config.secondary)
|
| 353 |
|
| 354 |
+
tertiary_breaker = CircuitBreaker(self.redis, self.config.tertiary)
|
| 355 |
+
if await tertiary_breaker.can_execute():
|
| 356 |
+
FALLBACK_COUNT.labels(
|
| 357 |
+
from_provider=self.config.secondary,
|
| 358 |
+
to_provider=self.config.tertiary,
|
| 359 |
+
reason="secondary_down",
|
| 360 |
+
).inc()
|
| 361 |
+
logger.warning(
|
| 362 |
+
f"Fallback: both {self.config.primary} and {self.config.secondary} down, "
|
| 363 |
+
f"switching to {self.config.tertiary}"
|
| 364 |
+
)
|
| 365 |
+
return self.config.tertiary, self._get_provider_config(self.config.tertiary)
|
| 366 |
+
|
| 367 |
if MLX_ENABLED:
|
| 368 |
mlx_breaker = CircuitBreaker(self.redis, "mlx")
|
| 369 |
if await mlx_breaker.can_execute():
|
| 370 |
FALLBACK_COUNT.labels(
|
| 371 |
+
from_provider=self.config.tertiary,
|
| 372 |
to_provider="mlx",
|
| 373 |
+
reason="all_cloud_down",
|
| 374 |
).inc()
|
| 375 |
+
logger.warning("All cloud providers down — falling back to MLX local")
|
| 376 |
return "mlx", self._get_provider_config("mlx")
|
| 377 |
|
| 378 |
raise HTTPException(status_code=503, detail="All LLM providers unavailable.")
|
|
|
|
| 393 |
"cost_per_1m_input": 0.0,
|
| 394 |
"cost_per_1m_output": 0.0,
|
| 395 |
},
|
| 396 |
+
"gemini": {
|
| 397 |
+
"api_base": GEMINI_API_BASE,
|
| 398 |
+
"api_key": GEMINI_API_KEY,
|
| 399 |
+
"rpm_limit": 60,
|
| 400 |
+
"cost_per_1m_input": 0.075,
|
| 401 |
+
"cost_per_1m_output": 0.30,
|
| 402 |
+
},
|
| 403 |
"mlx": {
|
| 404 |
"api_base": MLX_API_BASE,
|
| 405 |
"api_key": "no-key",
|
|
|
|
| 449 |
# ---------------------------------------------------------------------------
|
| 450 |
|
| 451 |
class ChatRequest(BaseModel):
|
| 452 |
+
model: str = Field(..., description="Model ID (e.g., gemma-4-31b-bf16)")
|
| 453 |
messages: list[dict] = Field(..., description="OpenAI-compatible messages")
|
| 454 |
temperature: Optional[float] = 0.7
|
| 455 |
max_tokens: Optional[int] = 4096
|
|
|
|
| 570 |
|
| 571 |
app = FastAPI(
|
| 572 |
title="ml-intern Production API",
|
| 573 |
+
description="Production-grade API with NIM/Cloudflare/Gemini fallback and MLX local support",
|
| 574 |
version="1.0.0",
|
| 575 |
lifespan=lifespan,
|
| 576 |
)
|
|
|
|
| 745 |
|
| 746 |
circuits = {}
|
| 747 |
fallback_status = {}
|
| 748 |
+
for provider in ["nim", "cloudflare", "gemini", "mlx"]:
|
| 749 |
try:
|
| 750 |
state = await redis_manager.get_circuit_state(provider)
|
| 751 |
circuits[provider] = state["state"]
|
|
|
|
| 779 |
|
| 780 |
await concurrency_limiter.acquire()
|
| 781 |
try:
|
|
|
|
| 782 |
if request.provider_override:
|
| 783 |
provider = request.provider_override
|
| 784 |
provider_config = fallback_manager._get_provider_config(provider)
|
|
|
|
| 790 |
|
| 791 |
request.state.provider = provider
|
| 792 |
|
|
|
|
| 793 |
rpm = provider_config.get("rpm_limit", DEFAULT_RPM_LIMIT)
|
| 794 |
rate_limit_key = f"{provider}:{session_id}"
|
| 795 |
allowed, retry_after = await redis_manager.check_rate_limit(rate_limit_key, rpm)
|
|
|
|
| 797 |
logger.warning(f"Rate limit exceeded for {rate_limit_key}")
|
| 798 |
raise HTTPException(status_code=429, detail=f"Rate limit exceeded. Retry after {retry_after:.1f}s", headers={"Retry-After": str(int(retry_after))})
|
| 799 |
|
|
|
|
| 800 |
if not request.stream:
|
| 801 |
cache_key = generate_cache_key(request)
|
| 802 |
cached = await redis_manager.get_cache(cache_key)
|
|
|
|
| 817 |
fallback_used=False,
|
| 818 |
)
|
| 819 |
|
|
|
|
| 820 |
cost_tracker = CostTracker(session_id, provider=provider, model=request.model)
|
| 821 |
|
|
|
|
| 822 |
response = await call_llm(provider, provider_config, request, session_id)
|
| 823 |
|
|
|
|
| 824 |
breaker = CircuitBreaker(redis_manager, provider)
|
| 825 |
await breaker.record_success()
|
| 826 |
|
| 827 |
cost_tracker.record_spend(response.cost_usd)
|
| 828 |
|
|
|
|
| 829 |
if not request.stream:
|
| 830 |
cache_key = generate_cache_key(request)
|
| 831 |
await redis_manager.set_cache(
|
|
|
|
| 838 |
}),
|
| 839 |
)
|
| 840 |
|
|
|
|
| 841 |
if DATABASE_URL:
|
| 842 |
background_tasks.add_task(_persist_request, session_id, request, response, provider)
|
| 843 |
|
|
|
|
| 844 |
if provider != FALLBACK_PRIMARY and FALLBACK_ENABLED:
|
| 845 |
response.fallback_used = True
|
| 846 |
|
|
|
|
| 896 |
{"id": "cloudflare/@cf/meta/llama-3.1-70b-instruct", "object": "model", "owned_by": "cloudflare"},
|
| 897 |
{"id": "cloudflare/@cf/mistral/mistral-7b-instruct", "object": "model", "owned_by": "cloudflare"},
|
| 898 |
{"id": "cloudflare/@cf/qwen/qwen1.5-14b-chat-awq", "object": "model", "owned_by": "cloudflare"},
|
| 899 |
+
{"id": "cloudflare/@cf/google/gemma-4-26b-a4b-it", "object": "model", "owned_by": "cloudflare"},
|
| 900 |
+
])
|
| 901 |
+
|
| 902 |
+
if GEMINI_API_KEY:
|
| 903 |
+
models.extend([
|
| 904 |
+
{"id": "gemini/gemini-2.5-pro-preview", "object": "model", "owned_by": "google"},
|
| 905 |
+
{"id": "gemini/gemini-2.5-flash-preview", "object": "model", "owned_by": "google"},
|
| 906 |
+
{"id": "gemini/gemma-4-26b", "object": "model", "owned_by": "google"},
|
| 907 |
+
{"id": "gemini/gemma-4-9b", "object": "model", "owned_by": "google"},
|
| 908 |
])
|
| 909 |
|
| 910 |
if MLX_ENABLED:
|
| 911 |
models.extend([
|
| 912 |
{"id": "mlx/llama-3.1-8b", "object": "model", "owned_by": "mlx"},
|
| 913 |
{"id": "mlx/llama-3.1-70b", "object": "model", "owned_by": "mlx"},
|
| 914 |
+
{"id": "mlx/gemma-4-26b-a4b-it", "object": "model", "owned_by": "mlx"},
|
| 915 |
+
{"id": "mlx/gemma-4-31b-bf16", "object": "model", "owned_by": "mlx"},
|
| 916 |
+
{"id": "mlx/gemma-4-e4b-it", "object": "model", "owned_by": "mlx"},
|
| 917 |
])
|
| 918 |
|
| 919 |
return {"object": "list", "data": models}
|
|
|
|
| 921 |
@app.get("/v1/fallback/status")
|
| 922 |
async def fallback_status():
|
| 923 |
status = {}
|
| 924 |
+
for provider in ["nim", "cloudflare", "gemini", "mlx"]:
|
| 925 |
breaker = CircuitBreaker(redis_manager, provider)
|
| 926 |
can_execute = await breaker.can_execute()
|
| 927 |
state = await redis_manager.get_circuit_state(provider)
|
|
|
|
| 936 |
"fallback_enabled": FALLBACK_ENABLED,
|
| 937 |
"primary": FALLBACK_PRIMARY,
|
| 938 |
"secondary": FALLBACK_SECONDARY,
|
| 939 |
+
"tertiary": FALLBACK_TERTIARY,
|
| 940 |
"providers": status,
|
| 941 |
"active_provider": await _get_active_provider_name(),
|
| 942 |
}
|