Elysiadev11 committed
Commit 529d4d4 · verified · 1 Parent(s): 79ff08c

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +792 -0
app.py ADDED
@@ -0,0 +1,792 @@
"""
Cerebras Proxy Server
- OpenAI-compatible endpoint: /v1/chat/completions
- Anthropic-compatible endpoint: /v1/messages
- Token limiting: max 30,000 request tokens (oldest messages auto-truncated)
- Multi-key round-robin with failover
"""

import os
import json
import time
import uuid
import asyncio
import httpx
import tiktoken

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.requests import ClientDisconnect

app = FastAPI()

# =====================================================
# CONFIG
# =====================================================
MASTER_API_KEY = os.getenv("MASTER_API_KEY", "olla")
CEREBRAS_BASE_URL = os.getenv("CEREBRAS_BASE_URL", "https://api.cerebras.ai/v1")
MAX_REQUEST_TOKENS = int(os.getenv("MAX_REQUEST_TOKENS", "30000"))

# Default model for Cerebras
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "llama-4-scout-17b-16e-instruct")

# Model mapping: incoming model name -> Cerebras model name
DEFAULT_MODEL_MAPPING = {
    # Claude models -> Cerebras
    "claude-opus-4-7": "llama-4-scout-17b-16e-instruct",
    "claude-opus-4-6": "llama-4-scout-17b-16e-instruct",
    "claude-opus-4-5": "llama-4-scout-17b-16e-instruct",
    "claude-opus-4-1": "llama-4-scout-17b-16e-instruct",
    "claude-opus-4-20250514": "llama-4-scout-17b-16e-instruct",
    "claude-sonnet-4-6": "llama-4-scout-17b-16e-instruct",
    "claude-sonnet-4-5": "llama-4-scout-17b-16e-instruct",
    "claude-sonnet-4-20250514": "llama-4-scout-17b-16e-instruct",
    "claude-haiku-4-5": "llama-4-scout-17b-16e-instruct",
    "claude-haiku-4-5-20251001": "llama-4-scout-17b-16e-instruct",
    # GPT models -> Cerebras
    "gpt-4": "llama-4-scout-17b-16e-instruct",
    "gpt-4o": "llama-4-scout-17b-16e-instruct",
    "gpt-4o-mini": "llama-4-scout-17b-16e-instruct",
    "gpt-4-turbo": "llama-4-scout-17b-16e-instruct",
    "gpt-3.5-turbo": "llama-4-scout-17b-16e-instruct",
}

def load_model_mapping():
    mapping = DEFAULT_MODEL_MAPPING.copy()
    env_map = os.getenv("MODEL_MAP")
    if env_map:
        for pair in env_map.split(","):
            if ":" in pair:
                name, target = pair.split(":", 1)
                mapping[name.strip()] = target.strip()
    return mapping

def map_model(model_name: str) -> str:
    mapping = load_model_mapping()
    if model_name in mapping:
        return mapping[model_name]
    # If the model is already a Cerebras model name, pass it through
    return model_name
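
# Examples (illustrative): with MODEL_MAP unset, names resolve through
# DEFAULT_MODEL_MAPPING and unknown names pass through unchanged:
#   map_model("gpt-4o")       -> "llama-4-scout-17b-16e-instruct"
#   map_model("llama3.3-70b") -> "llama3.3-70b"
# MODEL_MAP entries override/extend the table, e.g.:
#   MODEL_MAP="gpt-4o:llama3.3-70b,claude-sonnet-4-5:qwen-3-32b"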

# =====================================================
# API KEYS - Load from env: CEREBRAS_KEY_1, CEREBRAS_KEY_2, ...
# =====================================================
API_KEYS = []

for i in range(1, 101):
    key = os.getenv(f"CEREBRAS_KEY_{i}")
    if key:
        API_KEYS.append(key)

if not API_KEYS:
    # Fallback: check CEREBRAS_API_KEY
    fallback = os.getenv("CEREBRAS_API_KEY", "")
    if fallback:
        API_KEYS.append(fallback)
    else:
        API_KEYS.append("dummy_key")
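
# Example environment setup (the key values here are placeholders):
#   CEREBRAS_KEY_1=...  CEREBRAS_KEY_2=...  (checked up to CEREBRAS_KEY_100)
# or a single CEREBRAS_API_KEY. With nothing set, the "dummy_key" fallback
# lets the app boot, but every upstream call will fail authentication.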

# =====================================================
# KEY STATUS & ROUND ROBIN
# =====================================================
key_status = {}
for idx, k in enumerate(API_KEYS, 1):
    key_status[k] = {
        "index": idx,
        "prefix": k[:8] + "..." if len(k) > 8 else k,
        "healthy": True,
        "busy": False,
        "success": 0,
        "fail": 0,
    }

rr_index = 0
_key_lock = asyncio.Lock()

# =====================================================
# TOKEN COUNTING
# =====================================================
# Use cl100k_base (GPT-4 tokenizer) as a reasonable approximation
try:
    _encoder = tiktoken.get_encoding("cl100k_base")
except Exception:
    _encoder = None

def count_tokens(text: str) -> int:
    """Count tokens in text using tiktoken; fall back to a chars/4 estimate."""
    if _encoder is None:
        return len(text) // 4
    return len(_encoder.encode(text, disallowed_special=()))

def count_messages_tokens(messages: list) -> int:
    """Count total tokens in a list of messages."""
    total = 0
    for msg in messages:
        content = msg.get("content", "")
        if isinstance(content, list):
            for block in content:
                if isinstance(block, dict) and block.get("type") == "text":
                    total += count_tokens(block.get("text", ""))
        elif isinstance(content, str):
            total += count_tokens(content)
        # Add overhead for role, etc.
        total += 4  # ~4 tokens per message overhead
    return total
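
# Worked example (approximate, cl100k_base tokenization): the list
#   [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi"}]
# counts as 1 + 4 + 1 + 4 = 10 tokens: one token per one-word message plus
# the flat 4-token-per-message overhead added above.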

def truncate_messages(messages: list, max_tokens: int) -> list:
    """
    Truncate messages to fit within max_tokens.
    Strategy:
    1. Always keep the system message (first message if role=system)
    2. Always keep the last user message
    3. Remove oldest non-system messages first
    4. If still over limit, truncate the content of remaining messages
    """
    if not messages:
        return messages

    total = count_messages_tokens(messages)
    if total <= max_tokens:
        return messages

    log(f"⚠️ Token count {total} exceeds limit {max_tokens}. Truncating...")

    # Separate system messages from the rest
    system_msgs = []
    other_msgs = []

    for msg in messages:
        if msg.get("role") == "system":
            system_msgs.append(msg)
        else:
            other_msgs.append(msg)

    # Always keep the last message (usually the latest user message)
    if not other_msgs:
        return messages

    last_msg = other_msgs[-1]
    middle_msgs = other_msgs[:-1]

    # Try removing middle messages, oldest first
    result = system_msgs.copy()
    remaining_budget = max_tokens - count_messages_tokens(system_msgs) - count_messages_tokens([last_msg])

    if remaining_budget < 0:
        # Even system + last message exceeds the limit:
        # truncate the system message content first
        if system_msgs:
            sys_content = system_msgs[0].get("content", "")
            if isinstance(sys_content, str):
                # Keep only the first ~2000 tokens of the system prompt
                max_sys = min(2000, max_tokens // 4)
                if _encoder:
                    tokens = _encoder.encode(sys_content, disallowed_special=())
                    if len(tokens) > max_sys:
                        sys_content = _encoder.decode(tokens[:max_sys])
                else:
                    sys_content = sys_content[:max_sys * 4]
                system_msgs[0] = {**system_msgs[0], "content": sys_content}

        # Truncate the last message content too if needed
        last_content = last_msg.get("content", "")
        if isinstance(last_content, str):
            max_last = max_tokens - count_messages_tokens(system_msgs) - 10
            if max_last > 0:
                last_tokens = count_tokens(last_content)
                if last_tokens > max_last:
                    if _encoder:
                        tokens = _encoder.encode(last_content, disallowed_special=())
                        last_content = _encoder.decode(tokens[:max_last])
                    else:
                        last_content = last_content[:max_last * 4]
                    last_msg = {**last_msg, "content": last_content}

        result = system_msgs + [last_msg]
        final_count = count_messages_tokens(result)
        log(f"✂️ Truncated to {final_count} tokens (heavy truncation)")
        return result

    # Add middle messages from newest to oldest until the budget is exhausted
    kept_middle = []
    for msg in reversed(middle_msgs):
        msg_tokens = count_messages_tokens([msg])
        if remaining_budget >= msg_tokens:
            kept_middle.insert(0, msg)
            remaining_budget -= msg_tokens
        else:
            # Try to fit a truncated version of the first message that no
            # longer fits, then stop
            if remaining_budget > 50:  # Only bother if we have meaningful space
                content = msg.get("content", "")
                if isinstance(content, str) and remaining_budget > 10:
                    if _encoder:
                        tokens = _encoder.encode(content, disallowed_special=())
                        truncated = _encoder.decode(tokens[:remaining_budget - 10])
                    else:
                        truncated = content[:(remaining_budget - 10) * 4]
                    kept_middle.insert(0, {**msg, "content": truncated + "\n[...truncated]"})
                    remaining_budget = 0
            break

    result = system_msgs + kept_middle + [last_msg]
    final_count = count_messages_tokens(result)
    removed = len(middle_msgs) - len(kept_middle)
    log(f"✂️ Truncated: removed {removed} messages, final {final_count} tokens")
    return result
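
# Illustration (assuming a small token budget for brevity): given
#   [system, user1, assistant1, user2]
# where only system + user2 fit, the middle turns user1 and assistant1 are
# dropped (middle messages are re-added newest-first until the budget runs
# out), yielding [system, user2]. The heavy-truncation branch above triggers
# only when even system + user2 exceed the budget.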

# =====================================================
# UTILITY
# =====================================================
def log(msg):
    print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)

def sse(obj):
    return "data: " + json.dumps(obj, ensure_ascii=False) + "\n\n"
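
# Example: sse({"type": "ping"}) produces the SSE frame
#   'data: {"type": "ping"}\n\n'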

def auth_ok(req: Request):
    # Accept the key via the OpenAI-style "Authorization: Bearer <key>" header
    # or the Anthropic-style "x-api-key" header (used by /v1/messages clients)
    token = req.headers.get("Authorization", "").replace("Bearer ", "")
    if not token:
        token = req.headers.get("x-api-key", "")
    return token == MASTER_API_KEY

async def get_key(exclude=None):
    global rr_index
    if exclude is None:
        exclude = set()

    async with _key_lock:
        # Check if all keys are unhealthy, reset if so
        if not any(v["healthy"] for v in key_status.values()):
            log("⚠️ All API Keys unhealthy. Resetting all...")
            for v in key_status.values():
                v["fail"] = 0
                v["healthy"] = True

        for _ in range(len(API_KEYS)):
            rr_index = (rr_index + 1) % len(API_KEYS)
            key = API_KEYS[rr_index]
            st = key_status[key]

            if st["healthy"] and not st["busy"] and key not in exclude:
                st["busy"] = True
                return key

        return None

async def release_key(key):
    async with _key_lock:
        if key in key_status:
            key_status[key]["busy"] = False

async def mark_fail(key):
    async with _key_lock:
        if key in key_status:
            key_status[key]["fail"] += 1
            if key_status[key]["fail"] >= 3:
                key_status[key]["healthy"] = False

async def mark_ok(key):
    async with _key_lock:
        if key in key_status:
            key_status[key]["success"] += 1
            key_status[key]["fail"] = 0
            key_status[key]["healthy"] = True

async def wait_for_free_key(exclude=None, max_wait=60.0, interval=0.3):
    elapsed = 0.0
    while elapsed < max_wait:
        key = await get_key(exclude)
        if key:
            return key
        await asyncio.sleep(interval)
        elapsed += interval
    return None

def is_rate_limited(status_code: int, text: str) -> bool:
    t = text.lower()
    return status_code == 429 or "rate limit" in t or "too many requests" in t or "usage limit" in t
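
# Examples:
#   is_rate_limited(429, "")                       -> True  (status code)
#   is_rate_limited(400, "Daily usage limit hit")  -> True  (body text match)
#   is_rate_limited(200, '{"ok": true}')           -> False
# The text match is a heuristic: a response body that merely mentions
# "rate limit" would also trip it.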

# =====================================================
# ROOT / STATUS
# =====================================================
@app.get("/")
async def root():
    async with _key_lock:
        keys_info = {}
        for k, v in key_status.items():
            keys_info[v["prefix"]] = {
                "status": "BUSY" if v["busy"] else "IDLE",
                "healthy": v["healthy"],
                "success": v["success"],
                "fail": v["fail"],
            }

    return {
        "status": "ok",
        "backend": "cerebras",
        "base_url": CEREBRAS_BASE_URL,
        "default_model": DEFAULT_MODEL,
        "max_request_tokens": MAX_REQUEST_TOKENS,
        "total_keys": len(API_KEYS),
        "keys": keys_info,
    }

# =====================================================
# /v1/models
# =====================================================
@app.get("/v1/models")
async def list_models(req: Request):
    if not auth_ok(req):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)

    # Try to fetch the live model list from the Cerebras API
    key = API_KEYS[0] if API_KEYS else ""
    try:
        async with httpx.AsyncClient(timeout=30) as client:
            r = await client.get(
                f"{CEREBRAS_BASE_URL}/models",
                headers={"Authorization": f"Bearer {key}"}
            )
            if r.status_code == 200:
                return Response(content=r.content, media_type="application/json")
    except Exception as e:
        log(f"[/v1/models] Error fetching from Cerebras: {e}")

    # Fallback: return known models
    now = int(time.time())
    known_models = [
        "llama-4-scout-17b-16e-instruct",
        "llama-4-maverick-17b-128e-instruct",
        "llama3.3-70b",
        "llama3.1-8b",
        "qwen-3-32b",
        "deepseek-r1-distill-llama-70b",
    ]
    data = [
        {"id": m, "object": "model", "created": now, "owned_by": "cerebras"}
        for m in known_models
    ]
    return {"object": "list", "data": data}

# =====================================================
# /v1/chat/completions (OpenAI-compatible)
# =====================================================
@app.post("/v1/chat/completions")
async def chat(req: Request):
    if not auth_ok(req):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)

    try:
        body = await req.json()
    except ClientDisconnect:
        log("Client disconnected before reading body.")
        return Response(status_code=499)
    except json.JSONDecodeError:
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)

    is_stream = body.get("stream", False)
    original_model = body.get("model", DEFAULT_MODEL)
    cerebras_model = map_model(original_model)

    # Token limiting: truncate messages
    messages = body.get("messages", [])
    messages = truncate_messages(messages, MAX_REQUEST_TOKENS)

    # Build Cerebras request body
    cerebras_body = {
        "model": cerebras_model,
        "messages": messages,
        "stream": is_stream,
    }

    # Forward optional parameters
    for param in ["max_tokens", "max_completion_tokens", "temperature", "top_p", "stop", "frequency_penalty", "presence_penalty"]:
        if param in body:
            cerebras_body[param] = body[param]

    # Default the completion cap if the client didn't set one,
    # to avoid blowing Cerebras limits
    if "max_tokens" not in cerebras_body and "max_completion_tokens" not in cerebras_body:
        cerebras_body["max_completion_tokens"] = 8192

    # -----------------------------------------
    # NON STREAM
    # -----------------------------------------
    if not is_stream:
        tried = set()

        for _ in range(len(API_KEYS)):
            key = await wait_for_free_key(exclude=tried)
            if not key:
                break

            tried.add(key)
            ki = key_status[key]
            log(f"NON-STREAM: Using key#{ki['index']}")

            try:
                async with httpx.AsyncClient(timeout=180) as client:
                    r = await client.post(
                        f"{CEREBRAS_BASE_URL}/chat/completions",
                        json=cerebras_body,
                        headers={
                            "Authorization": f"Bearer {key}",
                            "Content-Type": "application/json",
                        }
                    )

                    if is_rate_limited(r.status_code, r.text):
                        log(f"RATE LIMITED: key#{ki['index']}, trying next")
                        await mark_fail(key)
                        continue

                    if r.status_code != 200:
                        log(f"HTTP {r.status_code}: key#{ki['index']}, trying next")
                        await mark_fail(key)
                        continue

                    await mark_ok(key)

                    # Cerebras returns OpenAI-compatible JSON, forward directly
                    return Response(content=r.content, media_type="application/json")

            except Exception as e:
                log(f"Exception: key#{ki['index']} - {e}")
                await mark_fail(key)

            finally:
                await release_key(key)

        return JSONResponse({"error": "All keys failed"}, status_code=500)

    # -----------------------------------------
    # STREAM
    # -----------------------------------------
    async def stream_gen():
        tried = set()

        for _ in range(len(API_KEYS)):
            key = await wait_for_free_key(exclude=tried)
            if not key:
                break

            tried.add(key)
            ki = key_status[key]
            log(f"STREAM: Using key#{ki['index']}")

            try:
                async with httpx.AsyncClient(timeout=None) as client:
                    async with client.stream(
                        "POST",
                        f"{CEREBRAS_BASE_URL}/chat/completions",
                        json=cerebras_body,
                        headers={
                            "Authorization": f"Bearer {key}",
                            "Content-Type": "application/json",
                        }
                    ) as r:

                        if is_rate_limited(r.status_code, ""):
                            log(f"STREAM RATE LIMITED: key#{ki['index']}, trying next")
                            await mark_fail(key)
                            continue

                        if r.status_code != 200:
                            log(f"STREAM HTTP {r.status_code}: key#{ki['index']}, trying next")
                            await mark_fail(key)
                            continue

                        hit_limit = False

                        async for line in r.aiter_lines():
                            if not line:
                                continue

                            if line.strip() == "data: [DONE]":
                                break

                            raw = line[6:] if line.startswith("data: ") else line
                            if is_rate_limited(0, raw):
                                log(f"MID-STREAM LIMIT: key#{ki['index']}, switching")
                                hit_limit = True
                                break

                            # Cerebras SSE is already OpenAI-compatible, pipe directly
                            yield line + "\n\n"

                        # Note: chunks already sent to the client before a
                        # mid-stream limit are not rolled back on failover
                        if hit_limit:
                            await mark_fail(key)
                            continue

                        yield "data: [DONE]\n\n"
                        await mark_ok(key)
                        return

            except Exception as e:
                log(f"STREAM EXCEPTION: key#{ki['index']} - {e}")
                await mark_fail(key)

            finally:
                await release_key(key)

        yield sse({"error": "All keys failed"})
        yield "data: [DONE]\n\n"

    return StreamingResponse(stream_gen(), media_type="text/event-stream")
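
# Streaming variant of the OpenAI-style request (illustrative, same host/key
# assumptions as the example near the top): set "stream": true and read the
# SSE frames, e.g.
#   curl -N http://localhost:7860/v1/chat/completions \
#     -H "Authorization: Bearer olla" -H "Content-Type: application/json" \
#     -d '{"model": "gpt-4o", "stream": true, "messages": [{"role": "user", "content": "Hi"}]}'
# Upstream Cerebras chunks are piped through unchanged, so any OpenAI SSE
# client can consume them.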

# =====================================================
# /v1/messages (Anthropic-compatible)
# =====================================================
@app.post("/v1/messages")
async def anthropic_messages(req: Request):
    if not auth_ok(req):
        return JSONResponse(
            {"type": "error", "error": {"type": "authentication_error", "message": "Unauthorized"}},
            status_code=401
        )

    try:
        body = await req.json()
    except ClientDisconnect:
        return Response(status_code=499)
    except Exception:
        return JSONResponse(
            {"type": "error", "error": {"type": "invalid_request_error", "message": "Bad JSON"}},
            status_code=400
        )

    is_stream = body.get("stream", False)
    original_model = body.get("model", DEFAULT_MODEL)
    cerebras_model = map_model(original_model)
    max_tokens = body.get("max_tokens", 4096)

    # Convert Anthropic messages -> OpenAI format
    messages = []

    if body.get("system"):
        sys_content = body["system"]
        if isinstance(sys_content, list):
            # Anthropic system can be a list of content blocks
            txt = "".join(x.get("text", "") for x in sys_content if x.get("type") == "text")
            sys_content = txt
        messages.append({"role": "system", "content": sys_content})

    for m in body.get("messages", []):
        content = m.get("content", "")
        if isinstance(content, list):
            txt = ""
            for block in content:
                if block.get("type") == "text":
                    txt += block.get("text", "")
                elif block.get("type") == "tool_result":
                    # tool_result content may itself be a list of blocks
                    c = block.get("content", "")
                    txt += c if isinstance(c, str) else json.dumps(c)
                elif block.get("type") == "tool_use":
                    txt += json.dumps(block)
            content = txt
        messages.append({"role": m["role"], "content": content})

    # Token limiting
    messages = truncate_messages(messages, MAX_REQUEST_TOKENS)

    cerebras_body = {
        "model": cerebras_model,
        "messages": messages,
        "stream": is_stream,
        "max_completion_tokens": min(max_tokens, 8192),
    }

    # Forward optional params
    if "temperature" in body:
        cerebras_body["temperature"] = body["temperature"]
    if "top_p" in body:
        cerebras_body["top_p"] = body["top_p"]

    # -----------------------------------------
    # NON STREAM
    # -----------------------------------------
    if not is_stream:
        tried = set()

        for _ in range(len(API_KEYS)):
            key = await wait_for_free_key(exclude=tried)
            if not key:
                break

            tried.add(key)
            ki = key_status[key]
            log(f"ANTHROPIC NON-STREAM: key#{ki['index']}")

            try:
                async with httpx.AsyncClient(timeout=180) as client:
                    r = await client.post(
                        f"{CEREBRAS_BASE_URL}/chat/completions",
                        json=cerebras_body,
                        headers={
                            "Authorization": f"Bearer {key}",
                            "Content-Type": "application/json",
                        }
                    )

                    if is_rate_limited(r.status_code, r.text):
                        log(f"RATE LIMITED: key#{ki['index']}")
                        await mark_fail(key)
                        continue

                    if r.status_code != 200:
                        log(f"HTTP {r.status_code}: key#{ki['index']}")
                        await mark_fail(key)
                        continue

                    data = r.json()

                    # Convert OpenAI response -> Anthropic format
                    content_text = data["choices"][0]["message"]["content"]
                    usage = data.get("usage", {})
                    finish = data["choices"][0].get("finish_reason", "stop")

                    stop_map = {"stop": "end_turn", "length": "max_tokens", "eos": "end_turn"}

                    out = {
                        "id": "msg_" + uuid.uuid4().hex[:10],
                        "type": "message",
                        "role": "assistant",
                        "model": original_model,
                        "content": [{"type": "text", "text": content_text}],
                        "stop_reason": stop_map.get(finish, "end_turn"),
                        "stop_sequence": None,
                        "usage": {
                            "input_tokens": usage.get("prompt_tokens", 0),
                            "output_tokens": usage.get("completion_tokens", 0),
                        }
                    }

                    await mark_ok(key)
                    return JSONResponse(out)

            except Exception as e:
                log(f"Exception: key#{ki['index']} - {e}")
                await mark_fail(key)

            finally:
                await release_key(key)

        return JSONResponse(
            {"type": "error", "error": {"type": "api_error", "message": "All keys failed"}},
            status_code=500
        )

    # -----------------------------------------
    # STREAM (Anthropic SSE envelope)
    # -----------------------------------------
    async def anthropic_stream_gen():
        tried = set()
        msg_id = "msg_" + uuid.uuid4().hex[:10]
        sent_header = False
        output_tokens = 0

        for _ in range(len(API_KEYS)):
            key = await wait_for_free_key(exclude=tried)
            if not key:
                break

            tried.add(key)
            ki = key_status[key]
            log(f"ANTHROPIC STREAM: key#{ki['index']}")

            try:
                async with httpx.AsyncClient(timeout=None) as client:
                    async with client.stream(
                        "POST",
                        f"{CEREBRAS_BASE_URL}/chat/completions",
                        json=cerebras_body,
                        headers={
                            "Authorization": f"Bearer {key}",
                            "Content-Type": "application/json",
                        }
                    ) as r:

                        if is_rate_limited(r.status_code, ""):
                            log(f"STREAM RATE LIMITED: key#{ki['index']}")
                            await mark_fail(key)
                            continue

                        if r.status_code != 200:
                            log(f"STREAM HTTP {r.status_code}: key#{ki['index']}")
                            await mark_fail(key)
                            continue

                        # Send the Anthropic envelope header (once)
                        if not sent_header:
                            yield sse({
                                "type": "message_start",
                                "message": {
                                    "id": msg_id,
                                    "type": "message",
                                    "role": "assistant",
                                    "model": original_model,
                                    "content": [],
                                    "stop_reason": None,
                                    "stop_sequence": None,
                                    "usage": {"input_tokens": 0, "output_tokens": 0}
                                }
                            })
                            yield sse({
                                "type": "content_block_start",
                                "index": 0,
                                "content_block": {"type": "text", "text": ""}
                            })
                            sent_header = True

                        hit_limit = False

                        async for line in r.aiter_lines():
                            if not line:
                                continue
                            if line.strip() == "data: [DONE]":
                                break

                            raw = line[6:] if line.startswith("data: ") else line

                            if is_rate_limited(0, raw):
                                log(f"MID-STREAM LIMIT: key#{ki['index']}")
                                hit_limit = True
                                break

                            try:
                                j = json.loads(raw)
                                # Read usage first: the final usage chunk may
                                # carry an empty choices list
                                if j.get("usage"):
                                    output_tokens = j["usage"].get("completion_tokens", output_tokens)
                                choices = j.get("choices") or []
                                token = choices[0]["delta"].get("content", "") if choices else ""
                            except Exception:
                                continue

                            if token:
                                yield sse({
                                    "type": "content_block_delta",
                                    "index": 0,
                                    "delta": {"type": "text_delta", "text": token}
                                })

                        if hit_limit:
                            await mark_fail(key)
                            continue

                        await mark_ok(key)
                        break  # success, exit retry loop

            except Exception as e:
                log(f"STREAM EXCEPTION: key#{ki['index']} - {e}")
                await mark_fail(key)

            finally:
                await release_key(key)

        if sent_header:
            # Close the Anthropic SSE envelope, reporting the last seen
            # completion-token count
            yield sse({"type": "content_block_stop", "index": 0})
            yield sse({
                "type": "message_delta",
                "delta": {"stop_reason": "end_turn", "stop_sequence": None},
                "usage": {"output_tokens": output_tokens}
            })
            yield sse({"type": "message_stop"})
        else:
            # No upstream stream ever started: emit an Anthropic-style error
            yield sse({"type": "error", "error": {"type": "api_error", "message": "All keys failed"}})

    return StreamingResponse(anthropic_stream_gen(), media_type="text/event-stream")
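
# Illustrative Anthropic-style call (same host/key assumptions as the
# examples above):
#   curl http://localhost:7860/v1/messages \
#     -H "Authorization: Bearer olla" -H "Content-Type: application/json" \
#     -d '{"model": "claude-sonnet-4-5", "max_tokens": 256, "messages": [{"role": "user", "content": "Hi"}]}'
#
# To run locally (assumption: serving with uvicorn, the usual FastAPI choice;
# port 7860 follows the Hugging Face Spaces convention):
#   uvicorn app:app --host 0.0.0.0 --port 7860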