# proxycf / proxy_cerebras.py
# Elysiadev11 — "Update proxy_cerebras.py" (commit 4ba202a, verified)
# NOTE: the three lines above were page-scrape residue (repo path, author,
# commit); commented out so the file parses as Python.
import asyncio
import hmac
import json
import os
import time
import uuid

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.requests import ClientDisconnect
app = FastAPI()  # ASGI application exposing the OpenAI/Anthropic-compatible proxy routes
# =====================================================
# CONFIG
# =====================================================
MASTER_API_KEY = os.getenv("MASTER_API_KEY", "olla")

# Default CF Workers AI model (can be overridden via the request body)
DEFAULT_CF_MODEL = os.getenv("DEFAULT_CF_MODEL", "@cf/meta/llama-3.3-70b-instruct-fp8-fast")

# =====================================================
# LOAD CF CREDENTIALS
# Env format: CF_1=account_id,api_key ... up to CF_100
# =====================================================
CF_ACCOUNTS = []  # list of {"account_id": ..., "api_key": ...}
for _slot in range(1, 101):
    _raw = os.getenv(f"CF_{_slot}")
    if not _raw:
        continue
    _pieces = _raw.split(",", 1)
    if len(_pieces) != 2:
        print(f"[WARN] CF_{_slot} format invalid, expected 'account_id,api_key' — skipped")
        continue
    _acct, _key = _pieces[0].strip(), _pieces[1].strip()
    if _acct and _key:
        CF_ACCOUNTS.append({"account_id": _acct, "api_key": _key})

if not CF_ACCOUNTS:
    print("[WARN] No CF credentials found, inserting dummy")
    CF_ACCOUNTS.append({"account_id": "dummy", "api_key": "dummy"})

# =====================================================
# KEY STATUS — per-account bookkeeping, keyed by account_id
# =====================================================
key_status = {
    acc["account_id"]: {
        "index": pos,
        "healthy": True,
        "busy": False,
        "success": 0,
        "fail": 0,
    }
    for pos, acc in enumerate(CF_ACCOUNTS, 1)
}
rr_index = 0  # round-robin cursor into CF_ACCOUNTS
_key_lock = asyncio.Lock()  # guards key_status and rr_index mutations
# =====================================================
# HELPERS
# =====================================================
def log(x):
    """Print *x* prefixed with an HH:MM:SS timestamp, flushing immediately."""
    stamp = time.strftime('%H:%M:%S')
    print(f"[{stamp}] {x}", flush=True)
def sse(obj):
    """Serialize *obj* into a single Server-Sent-Events data frame."""
    payload = json.dumps(obj, ensure_ascii=False)
    return f"data: {payload}\n\n"
def auth_ok(req: Request) -> bool:
    """Check the request's Authorization header against MASTER_API_KEY.

    Accepts either "Bearer <key>" or the bare key (backward compatible with
    the original behavior). Two fixes over the original:
    - str.replace("Bearer ", "") removed the substring *anywhere* in the
      header (so e.g. "Bearer Bearer olla" authenticated); now only a
      leading scheme prefix is stripped.
    - hmac.compare_digest gives a constant-time comparison for the secret.
    """
    header = req.headers.get("Authorization", "")
    token = header[len("Bearer "):] if header.startswith("Bearer ") else header
    return hmac.compare_digest(token, MASTER_API_KEY)
def cf_url(account_id: str, model: str) -> str:
    """Build the Cloudflare Workers AI run URL for an account/model pair."""
    base = "https://api.cloudflare.com/client/v4/accounts"
    return f"{base}/{account_id}/ai/run/{model}"
async def get_key(exclude=None):
    """Reserve the next healthy, idle CF account in round-robin order.

    Accounts whose account_id is in *exclude* are skipped. On success the
    account's busy flag is set under the lock and the account dict
    ({"account_id": ..., "api_key": ...}) is returned; otherwise None.
    """
    global rr_index
    skip = exclude if exclude is not None else set()
    async with _key_lock:
        for _ in CF_ACCOUNTS:
            rr_index = (rr_index + 1) % len(CF_ACCOUNTS)
            candidate = CF_ACCOUNTS[rr_index]
            kid = candidate["account_id"]
            state = key_status[kid]
            if state["healthy"] and not state["busy"] and kid not in skip:
                state["busy"] = True
                return candidate
    return None
async def release_key(acc):
    """Clear the busy flag for *acc* so other requests may reserve it."""
    async with _key_lock:
        state = key_status.get(acc["account_id"])
        if state is not None:
            state["busy"] = False
async def mark_fail(acc):
    """Record one failed call for *acc*.

    NOTE(review): only the "fail" counter is incremented here; nothing in
    the visible code ever sets "healthy" to False, so failures never
    actually disable an account — confirm whether that is intended.
    """
    async with _key_lock:
        state = key_status.get(acc["account_id"])
        if state is not None:
            state["fail"] += 1
async def mark_ok(acc):
    """Record one successful call for *acc* and reset its failure streak."""
    async with _key_lock:
        state = key_status.get(acc["account_id"])
        if state is not None:
            state["success"] += 1
            state["fail"] = 0
async def wait_for_free_key(exclude=None, max_wait=30.0, interval=0.3):
    """Poll get_key() until an account frees up or *max_wait* seconds pass.

    Args:
        exclude: account_ids to skip (forwarded to get_key).
        max_wait: total seconds to keep polling before giving up.
        interval: sleep between polls, in seconds.

    Returns the reserved account dict, or None on timeout.

    Fix over the original: elapsed time was approximated by summing the
    sleep intervals only, so time spent waiting on the lock inside
    get_key() never counted toward max_wait; a monotonic deadline does.
    """
    deadline = time.monotonic() + max_wait
    while time.monotonic() < deadline:
        acc = await get_key(exclude)
        if acc:
            return acc
        await asyncio.sleep(interval)
    return None
def is_rate_limited(status_code: int, text: str) -> bool:
    """Return True when a status code or body text indicates quota exhaustion."""
    lowered = text.lower()
    markers = ("rate limit", "too many requests", "usage limit")
    return status_code == 429 or any(m in lowered for m in markers)
# =====================================================
# ROOT
# =====================================================
@app.get("/")
async def root():
    """Status endpoint: account count, default model, per-account stats.

    Account ids are masked (first 6 + last 4 chars) so the status page
    does not leak full Cloudflare account ids.
    """
    async with _key_lock:
        detail = {}
        for kid, st in key_status.items():
            label = f"{kid[:6]}****{kid[-4:]}"
            detail[label] = {
                "index": st["index"],
                "healthy": st["healthy"],
                "busy": st["busy"],
                "success": st["success"],
                "fail": st["fail"],
            }
    return {
        "status": "ok",
        "accounts": len(CF_ACCOUNTS),
        "default_model": DEFAULT_CF_MODEL,
        "detail": detail,
    }
# =====================================================
# /v1/models — static list of popular CF models
# =====================================================
@app.get("/v1/models")
async def models(req: Request):
    """OpenAI-compatible model listing: a static set of popular CF models."""
    if not auth_ok(req):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    created = int(time.time())
    catalog = (
        "@cf/meta/llama-3.3-70b-instruct-fp8-fast",
        "@cf/meta/llama-3.1-8b-instruct",
        "@cf/meta/llama-3.1-70b-instruct",
        "@cf/mistral/mistral-7b-instruct-v0.1",
        "@cf/google/gemma-7b-it",
        "@cf/qwen/qwen1.5-14b-chat-awq",
        "@cf/deepseek-ai/deepseek-r1-distill-qwen-32b",
    )
    return {
        "object": "list",
        "data": [
            {"id": name, "object": "model", "created": created, "owned_by": "cloudflare"}
            for name in catalog
        ],
    }
# =====================================================
# /v1/chat/completions — OpenAI-compatible endpoint
# =====================================================
@app.post("/v1/chat/completions")
async def chat(req: Request):
    """OpenAI-compatible /v1/chat/completions proxied to Cloudflare Workers AI.

    Auth is a single shared bearer token checked by auth_ok(). The body is
    reshaped into a CF Workers AI payload and sent using a rotating pool of
    CF accounts: on rate limit, non-200 status, or exception the current
    account is marked failed and the next free one is tried — at most one
    attempt per account. Supports non-streaming JSON and SSE streaming.
    """
    if not auth_ok(req):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    try:
        body = await req.json()
    except Exception:
        return JSONResponse({"error": "Bad JSON"}, status_code=400)
    is_stream = body.get("stream", False)
    model = body.get("model", DEFAULT_CF_MODEL)
    messages = body.get("messages", [])
    max_tokens = body.get("max_tokens", 2048)
    # Payload forwarded to CF Workers AI (OpenAI-style messages accepted as-is).
    cf_body = {
        "messages": messages,
        "stream": is_stream,
        "max_tokens": max_tokens,
    }
    # -----------------------------------------
    # NON STREAM
    # -----------------------------------------
    if not is_stream:
        tried = set()  # account_ids already attempted for this request
        for _ in range(len(CF_ACCOUNTS)):
            acc = await wait_for_free_key(exclude=tried)
            if not acc:
                break  # no account freed up within the wait window
            tried.add(acc["account_id"])
            try:
                async with httpx.AsyncClient(timeout=180) as client:
                    r = await client.post(
                        cf_url(acc["account_id"], model),
                        json=cf_body,
                        headers={
                            "Authorization": f"Bearer {acc['api_key']}",
                            "Content-Type": "application/json",
                        }
                    )
                    if is_rate_limited(r.status_code, r.text):
                        log(f"Account {acc['account_id'][:8]}... rate limited (non-stream), trying next")
                        await mark_fail(acc)
                        continue
                    if r.status_code != 200:
                        log(f"Account {acc['account_id'][:8]}... HTTP {r.status_code}, trying next")
                        await mark_fail(acc)
                        continue
                    data = r.json()
                    # CF Workers AI response format:
                    # {"result": {"response": "..."}, "success": true, ...}
                    # Convert to OpenAI format
                    cf_result = data.get("result", {})
                    content = cf_result.get("response", "")
                    out = {
                        "id": "chatcmpl-" + uuid.uuid4().hex[:10],
                        "object": "chat.completion",
                        "created": int(time.time()),
                        "model": model,
                        "choices": [
                            {
                                "index": 0,
                                "message": {"role": "assistant", "content": content},
                                "finish_reason": "stop",
                            }
                        ],
                        # CF does not report token counts here; zeros keep the
                        # OpenAI response shape intact for clients.
                        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
                    }
                    await mark_ok(acc)
                    return JSONResponse(out)
            except Exception as e:
                log(f"Account {acc['account_id'][:8]}... exception: {e}")
                await mark_fail(acc)
            finally:
                # Always free the account, whether we returned, continued, or raised.
                await release_key(acc)
        return JSONResponse({"error": "All accounts failed"}, status_code=500)
    # -----------------------------------------
    # STREAM
    # CF Workers AI streams NDJSON lines:
    # {"response":"token"} or {"p":"...","response":"token"} and ends with [DONE]
    # We convert to OpenAI SSE format
    # -----------------------------------------
    async def gen():
        # NOTE(review): if an account hits its limit mid-stream, the whole
        # request is retried on the next account; tokens already yielded are
        # not rolled back, so the client may see duplicated text.
        tried = set()
        cid = "chatcmpl-" + uuid.uuid4().hex[:10]
        sent_any = False  # NOTE(review): written below but never read — vestigial
        for _ in range(len(CF_ACCOUNTS)):
            acc = await wait_for_free_key(exclude=tried)
            if not acc:
                break
            tried.add(acc["account_id"])
            try:
                async with httpx.AsyncClient(timeout=None) as client:
                    async with client.stream(
                        "POST",
                        cf_url(acc["account_id"], model),
                        json=cf_body,
                        headers={
                            "Authorization": f"Bearer {acc['api_key']}",
                            "Content-Type": "application/json",
                        }
                    ) as r:
                        # Body has not been read yet, so only the status code
                        # can be inspected here (text is passed as "").
                        if is_rate_limited(r.status_code, ""):
                            log(f"Account {acc['account_id'][:8]}... rate limited (stream), trying next")
                            await mark_fail(acc)
                            continue
                        if r.status_code != 200:
                            log(f"Account {acc['account_id'][:8]}... HTTP {r.status_code} (stream), trying next")
                            await mark_fail(acc)
                            continue
                        hit_limit = False
                        async for line in r.aiter_lines():
                            line = line.strip()
                            if not line:
                                continue
                            if line == "data: [DONE]" or line == "[DONE]":
                                break
                            # Strip "data: " prefix if present
                            raw = line[6:] if line.startswith("data: ") else line
                            # Detect mid-stream rate limit
                            if is_rate_limited(0, raw):
                                log(f"Account {acc['account_id'][:8]}... mid-stream limit, switching key")
                                hit_limit = True
                                break
                            try:
                                j = json.loads(raw)
                            except Exception:
                                # Skip malformed / partial NDJSON lines.
                                continue
                            token = j.get("response", "")
                            if token:
                                sent_any = True
                                chunk = {
                                    "id": cid,
                                    "object": "chat.completion.chunk",
                                    "created": int(time.time()),
                                    "model": model,
                                    "choices": [
                                        {
                                            "index": 0,
                                            "delta": {"role": "assistant", "content": token},
                                            "finish_reason": None,
                                        }
                                    ]
                                }
                                yield sse(chunk)
                        if hit_limit:
                            # Retry with the next account; no finish chunk yet.
                            await mark_fail(acc)
                            continue
                        # Send finish chunk
                        finish_chunk = {
                            "id": cid,
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": model,
                            "choices": [
                                {
                                    "index": 0,
                                    "delta": {},
                                    "finish_reason": "stop",
                                }
                            ]
                        }
                        yield sse(finish_chunk)
                        yield "data: [DONE]\n\n"
                        await mark_ok(acc)
                        return
            except Exception as e:
                log(f"Account {acc['account_id'][:8]}... stream exception: {e}")
                await mark_fail(acc)
            finally:
                await release_key(acc)
        # Every account was tried (or none became free): report failure in-band.
        yield sse({"error": "All accounts failed"})
        yield "data: [DONE]\n\n"
    return StreamingResponse(gen(), media_type="text/event-stream")
# =====================================================
# /v1/messages — Anthropic-compatible endpoint
# =====================================================
@app.post("/v1/messages")
async def anthropic(req: Request):
if not auth_ok(req):
return JSONResponse({"error": "Unauthorized"}, status_code=401)
try:
body = await req.json()
except ClientDisconnect:
return Response(status_code=499)
except Exception:
return JSONResponse({"error": "Bad JSON"}, status_code=400)
stream = body.get("stream", False)
model = body.get("model", DEFAULT_CF_MODEL)
max_tokens = body.get("max_tokens", 2048)
# Convert Anthropic message format to CF/OpenAI format
messages = []
if body.get("system"):
messages.append({"role": "system", "content": body["system"]})
for m in body.get("messages", []):
content = m.get("content", "")
if isinstance(content, list):
txt = ""
for x in content:
if x.get("type") == "text":
txt += x.get("text", "")
content = txt
messages.append({"role": m["role"], "content": content})
cf_body = {
"messages": messages,
"stream": stream,
"max_tokens": max_tokens,
}
# -----------------------------------------
# NON STREAM
# -----------------------------------------
if not stream:
tried = set()
for _ in range(len(CF_ACCOUNTS)):
acc = await wait_for_free_key(exclude=tried)
if not acc:
break
tried.add(acc["account_id"])
try:
async with httpx.AsyncClient(timeout=180) as client:
r = await client.post(
cf_url(acc["account_id"], model),
json=cf_body,
headers={
"Authorization": f"Bearer {acc['api_key']}",
"Content-Type": "application/json",
}
)
if is_rate_limited(r.status_code, r.text):
log(f"Account {acc['account_id'][:8]}... rate limited (anthropic non-stream), trying next")
await mark_fail(acc)
continue
if r.status_code != 200:
log(f"Account {acc['account_id'][:8]}... HTTP {r.status_code}, trying next")
await mark_fail(acc)
continue
data = r.json()
cf_result = data.get("result", {})
content = cf_result.get("response", "")
out = {
"id": "msg_" + uuid.uuid4().hex[:10],
"type": "message",
"role": "assistant",
"model": body.get("model", DEFAULT_CF_MODEL),
"content": [{"type": "text", "text": content}],
"stop_reason": "end_turn",
"stop_sequence": None,
"usage": {"input_tokens": 0, "output_tokens": 0}
}
await mark_ok(acc)
return JSONResponse(out)
except Exception as e:
log(f"Account {acc['account_id'][:8]}... exception: {e}")
await mark_fail(acc)
finally:
await release_key(acc)
return JSONResponse({"error": "All accounts failed"}, status_code=500)
# -----------------------------------------
# STREAM (Anthropic SSE envelope)
# -----------------------------------------
async def agen():
tried = set()
msg_id = "msg_" + uuid.uuid4().hex[:10]
sent_any_delta = False
for _ in range(len(CF_ACCOUNTS)):
acc = await wait_for_free_key(exclude=tried)
if not acc:
break
tried.add(acc["account_id"])
try:
async with httpx.AsyncClient(timeout=None) as client:
async with client.stream(
"POST",
cf_url(acc["account_id"], model),
json=cf_body,
headers={
"Authorization": f"Bearer {acc['api_key']}",
"Content-Type": "application/json",
}
) as r:
if is_rate_limited(r.status_code, ""):
log(f"Account {acc['account_id'][:8]}... rate limited (anthropic stream), trying next")
await mark_fail(acc)
continue
if r.status_code != 200:
log(f"Account {acc['account_id'][:8]}... HTTP {r.status_code} (anthropic stream), trying next")
await mark_fail(acc)
continue
# Emit Anthropic envelope only once on first successful key
if not sent_any_delta:
yield sse({
"type": "message_start",
"message": {
"id": msg_id,
"type": "message",
"role": "assistant",
"model": body.get("model", DEFAULT_CF_MODEL),
"content": [],
"stop_reason": None,
"stop_sequence": None,
"usage": {"input_tokens": 0, "output_tokens": 0}
}
})
yield sse({
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text"}
})
hit_limit = False
async for line in r.aiter_lines():
line = line.strip()
if not line:
continue
if line == "data: [DONE]" or line == "[DONE]":
break
raw = line[6:] if line.startswith("data: ") else line
if is_rate_limited(0, raw):
log(f"Account {acc['account_id'][:8]}... mid-stream limit (anthr