Elysiadev11 committed on
Commit
3f83b4e
·
verified ·
1 Parent(s): 574c72e

Update proxy_cerebras.py

Browse files
Files changed (1) hide show
  1. proxy_cerebras.py +167 -143
proxy_cerebras.py CHANGED
@@ -41,12 +41,15 @@ for idx, k in enumerate(OLLAMA_KEYS, 1):
41
 
42
  rr_index = 0
43
 
 
 
 
44
 
45
  # =====================================================
46
  # HELPERS
47
  # =====================================================
48
  def log(x):
49
- print(f"[{time.strftime('%H:%M:%S')}] {x}")
50
 
51
 
52
  def sse(obj):
@@ -58,39 +61,61 @@ def auth_ok(req: Request):
58
  return token == MASTER_API_KEY
59
 
60
 
61
- def get_key(exclude=None):
 
 
 
 
62
  global rr_index
63
 
64
  if exclude is None:
65
  exclude = set()
66
 
67
- for _ in range(len(OLLAMA_KEYS)):
68
- rr_index = (rr_index + 1) % len(OLLAMA_KEYS)
69
- k = OLLAMA_KEYS[rr_index]
70
-
71
- st = key_status[k]
72
 
73
- if st["healthy"] and not st["busy"] and k not in exclude:
74
- st["busy"] = True
75
- return k
76
 
77
  return None
78
 
79
 
80
- def release_key(k):
81
- if k in key_status:
82
- key_status[k]["busy"] = False
 
83
 
84
 
85
- def mark_fail(k):
86
- if k in key_status:
87
- key_status[k]["fail"] += 1
 
88
 
89
 
90
- def mark_ok(k):
91
- if k in key_status:
92
- key_status[k]["success"] += 1
93
- key_status[k]["fail"] = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
 
96
  # =====================================================
@@ -98,17 +123,17 @@ def mark_ok(k):
98
  # =====================================================
99
  @app.get("/")
100
  async def root():
101
- safe = {}
102
-
103
- for k, v in key_status.items():
104
- masked = k[:4] + "****" + k[-4:]
105
- safe[masked] = {
106
- "index": v["index"],
107
- "healthy": v["healthy"],
108
- "busy": v["busy"],
109
- "success": v["success"],
110
- "fail": v["fail"],
111
- }
112
 
113
  return {
114
  "status": "ok",
@@ -116,6 +141,7 @@ async def root():
116
  "detail": safe
117
  }
118
 
 
119
  # =====================================================
120
  # /v1/models
121
  # =====================================================
@@ -136,9 +162,8 @@ async def models(req: Request):
136
  return JSONResponse({"error": r.text}, status_code=r.status_code)
137
 
138
  data = r.json()
139
-
140
- out = []
141
  now = int(time.time())
 
142
 
143
  for m in data.get("models", []):
144
  out.append({
@@ -152,7 +177,7 @@ async def models(req: Request):
152
 
153
 
154
  # =====================================================
155
- # OPENAI CHAT
156
  # =====================================================
157
  @app.post("/v1/chat/completions")
158
  async def chat(req: Request):
@@ -161,7 +186,7 @@ async def chat(req: Request):
161
 
162
  try:
163
  body = await req.json()
164
- except:
165
  return JSONResponse({"error": "Bad JSON"}, status_code=400)
166
 
167
  is_stream = body.get("stream", False)
@@ -173,11 +198,10 @@ async def chat(req: Request):
173
  tried = set()
174
 
175
  for _ in range(len(OLLAMA_KEYS)):
176
- key = get_key(tried)
177
 
178
  if not key:
179
- await asyncio.sleep(0.3)
180
- continue
181
 
182
  tried.add(key)
183
 
@@ -192,25 +216,23 @@ async def chat(req: Request):
192
  txt = r.text.lower()
193
 
194
  if "weekly usage limit" in txt or r.status_code == 429:
195
- mark_fail(key)
 
196
  continue
197
 
198
- mark_ok(key)
199
 
200
  return Response(
201
  content=r.content,
202
- media_type=r.headers.get(
203
- "content-type",
204
- "application/json"
205
- )
206
  )
207
 
208
  except Exception as e:
209
- log(e)
210
- mark_fail(key)
211
 
212
  finally:
213
- release_key(key)
214
 
215
  return JSONResponse({"error": "All keys failed"}, status_code=500)
216
 
@@ -221,11 +243,10 @@ async def chat(req: Request):
221
  tried = set()
222
 
223
  for _ in range(len(OLLAMA_KEYS)):
224
- key = get_key(tried)
225
 
226
  if not key:
227
- await asyncio.sleep(0.3)
228
- continue
229
 
230
  tried.add(key)
231
 
@@ -239,22 +260,37 @@ async def chat(req: Request):
239
  ) as r:
240
 
241
  if r.status_code == 429:
242
- mark_fail(key)
 
243
  continue
244
 
 
 
245
  async for line in r.aiter_lines():
246
- if line:
247
- yield line + "\n\n"
 
 
 
 
 
 
248
 
249
- mark_ok(key)
 
 
 
 
 
 
250
  return
251
 
252
  except Exception as e:
253
- log(e)
254
- mark_fail(key)
255
 
256
  finally:
257
- release_key(key)
258
 
259
  yield sse({"error": "All keys failed"})
260
  yield "data: [DONE]\n\n"
@@ -274,16 +310,16 @@ async def anthropic(req: Request):
274
  body = await req.json()
275
  except ClientDisconnect:
276
  return Response(status_code=499)
 
 
277
 
278
  stream = body.get("stream", False)
279
 
 
280
  messages = []
281
 
282
  if body.get("system"):
283
- messages.append({
284
- "role": "system",
285
- "content": body["system"]
286
- })
287
 
288
  for m in body.get("messages", []):
289
  content = m.get("content", "")
@@ -295,10 +331,7 @@ async def anthropic(req: Request):
295
  txt += x.get("text", "")
296
  content = txt
297
 
298
- messages.append({
299
- "role": m["role"],
300
- "content": content
301
- })
302
 
303
  proxy_body = {
304
  "model": "minimax-m2.7:cloud",
@@ -313,11 +346,10 @@ async def anthropic(req: Request):
313
  tried = set()
314
 
315
  for _ in range(len(OLLAMA_KEYS)):
316
- key = get_key(tried)
317
 
318
  if not key:
319
- await asyncio.sleep(0.3)
320
- continue
321
 
322
  tried.add(key)
323
 
@@ -332,11 +364,11 @@ async def anthropic(req: Request):
332
  txt = r.text.lower()
333
 
334
  if "weekly usage limit" in txt or r.status_code == 429:
335
- mark_fail(key)
 
336
  continue
337
 
338
  data = r.json()
339
-
340
  ans = data["choices"][0]["message"]["content"]
341
 
342
  out = {
@@ -344,70 +376,42 @@ async def anthropic(req: Request):
344
  "type": "message",
345
  "role": "assistant",
346
  "model": body.get("model", "claude-opus-4-7"),
347
- "content": [
348
- {
349
- "type": "text",
350
- "text": ans
351
- }
352
- ],
353
  "stop_reason": "end_turn",
354
  "stop_sequence": None,
355
- "usage": {
356
- "input_tokens": 0,
357
- "output_tokens": 0
358
- }
359
  }
360
 
361
- mark_ok(key)
362
  return JSONResponse(out)
363
 
364
  except Exception as e:
365
- log(e)
366
- mark_fail(key)
367
 
368
  finally:
369
- release_key(key)
370
 
371
  return JSONResponse({"error": "All keys failed"}, status_code=500)
372
 
373
  # -----------------------------------------
374
- # STREAM
375
  # -----------------------------------------
376
  async def agen():
377
  tried = set()
378
  msg_id = "msg_" + uuid.uuid4().hex[:10]
 
379
 
380
- start_payload = {
381
- "type": "message_start",
382
- "message": {
383
- "id": msg_id,
384
- "type": "message",
385
- "role": "assistant",
386
- "model": body.get("model", "claude-opus-4-7"),
387
- "content": [],
388
- "stop_reason": None,
389
- "stop_sequence": None,
390
- "usage": {
391
- "input_tokens": 0,
392
- "output_tokens": 0
393
- }
394
- }
395
- }
396
-
397
- yield sse(start_payload)
398
-
399
- yield sse({
400
- "type": "content_block_start",
401
- "index": 0,
402
- "content_block": {"type": "text"}
403
- })
404
 
405
  for _ in range(len(OLLAMA_KEYS)):
406
- key = get_key(tried)
407
 
408
  if not key:
409
- await asyncio.sleep(0.3)
410
- continue
411
 
412
  tried.add(key)
413
 
@@ -421,9 +425,33 @@ async def anthropic(req: Request):
421
  ) as r:
422
 
423
  if r.status_code == 429:
424
- mark_fail(key)
 
425
  continue
426
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
  async for line in r.aiter_lines():
428
  if not line.startswith("data: "):
429
  continue
@@ -433,55 +461,51 @@ async def anthropic(req: Request):
433
  if raw == "[DONE]":
434
  break
435
 
 
 
 
 
 
 
436
  try:
437
  j = json.loads(raw)
438
- except:
439
  continue
440
 
441
  delta = j["choices"][0]["delta"]
442
  txt = delta.get("content", "")
443
 
444
  if txt:
 
445
  yield sse({
446
  "type": "content_block_delta",
447
  "index": 0,
448
- "delta": {
449
- "type": "text_delta",
450
- "text": txt
451
- }
452
  })
453
 
454
- mark_ok(key)
455
- break
 
 
 
 
 
 
456
 
457
  except Exception as e:
458
- log(e)
459
- mark_fail(key)
460
 
461
  finally:
462
- release_key(key)
463
-
464
- yield sse({
465
- "type": "content_block_stop",
466
- "index": 0
467
- })
468
 
 
 
469
  yield sse({
470
  "type": "message_delta",
471
- "delta": {
472
- "stop_reason": "end_turn",
473
- "stop_sequence": None
474
- },
475
- "usage": {
476
- "output_tokens": 0
477
- }
478
- })
479
-
480
- yield sse({
481
- "type": "message_stop"
482
  })
 
483
 
484
- return StreamingResponse(
485
- agen(),
486
- media_type="text/event-stream"
487
- )
 
41
 
42
  rr_index = 0
43
 
44
+ # Global async lock to prevent race condition on rr_index & busy flag
45
+ _key_lock = asyncio.Lock()
46
+
47
 
48
  # =====================================================
49
  # HELPERS
50
  # =====================================================
51
  def log(x):
52
+ print(f"[{time.strftime('%H:%M:%S')}] {x}", flush=True)
53
 
54
 
55
  def sse(obj):
 
61
  return token == MASTER_API_KEY
62
 
63
 
64
+ async def get_key(exclude=None):
65
+ """
66
+ Coroutine-safe round-robin key picker (serialized by the asyncio lock; not thread-safe).
67
+ Marks the chosen key busy and returns it, or returns None if every key is busy, unhealthy, or excluded.
68
+ """
69
  global rr_index
70
 
71
  if exclude is None:
72
  exclude = set()
73
 
74
+ async with _key_lock:
75
+ for _ in range(len(OLLAMA_KEYS)):
76
+ rr_index = (rr_index + 1) % len(OLLAMA_KEYS)
77
+ k = OLLAMA_KEYS[rr_index]
78
+ st = key_status[k]
79
 
80
+ if st["healthy"] and not st["busy"] and k not in exclude:
81
+ st["busy"] = True
82
+ return k
83
 
84
  return None
85
 
86
 
87
+ async def release_key(k):
88
+ async with _key_lock:
89
+ if k in key_status:
90
+ key_status[k]["busy"] = False
91
 
92
 
93
+ async def mark_fail(k):
94
+ async with _key_lock:
95
+ if k in key_status:
96
+ key_status[k]["fail"] += 1
97
 
98
 
99
+ async def mark_ok(k):
100
+ async with _key_lock:
101
+ if k in key_status:
102
+ key_status[k]["success"] += 1
103
+ key_status[k]["fail"] = 0
104
+
105
+
106
+ async def wait_for_free_key(exclude=None, max_wait=30.0, interval=0.3):
107
+ """
108
+ Polls until a free key is available or max_wait seconds pass.
109
+ Returns the key or None on timeout.
110
+ """
111
+ elapsed = 0.0
112
+ while elapsed < max_wait:
113
+ key = await get_key(exclude)
114
+ if key:
115
+ return key
116
+ await asyncio.sleep(interval)
117
+ elapsed += interval
118
+ return None
119
 
120
 
121
  # =====================================================
 
123
  # =====================================================
124
  @app.get("/")
125
  async def root():
126
+ async with _key_lock:
127
+ safe = {}
128
+ for k, v in key_status.items():
129
+ masked = k[:4] + "****" + k[-4:]
130
+ safe[masked] = {
131
+ "index": v["index"],
132
+ "healthy": v["healthy"],
133
+ "busy": v["busy"],
134
+ "success": v["success"],
135
+ "fail": v["fail"],
136
+ }
137
 
138
  return {
139
  "status": "ok",
 
141
  "detail": safe
142
  }
143
 
144
+
145
  # =====================================================
146
  # /v1/models
147
  # =====================================================
 
162
  return JSONResponse({"error": r.text}, status_code=r.status_code)
163
 
164
  data = r.json()
 
 
165
  now = int(time.time())
166
+ out = []
167
 
168
  for m in data.get("models", []):
169
  out.append({
 
177
 
178
 
179
  # =====================================================
180
+ # OPENAI CHAT /v1/chat/completions
181
  # =====================================================
182
  @app.post("/v1/chat/completions")
183
  async def chat(req: Request):
 
186
 
187
  try:
188
  body = await req.json()
189
+ except Exception:
190
  return JSONResponse({"error": "Bad JSON"}, status_code=400)
191
 
192
  is_stream = body.get("stream", False)
 
198
  tried = set()
199
 
200
  for _ in range(len(OLLAMA_KEYS)):
201
+ key = await wait_for_free_key(exclude=tried)
202
 
203
  if not key:
204
+ break
 
205
 
206
  tried.add(key)
207
 
 
216
  txt = r.text.lower()
217
 
218
  if "weekly usage limit" in txt or r.status_code == 429:
219
+ log(f"Key {key[:8]}... rate limited (non-stream chat), trying next")
220
+ await mark_fail(key)
221
  continue
222
 
223
+ await mark_ok(key)
224
 
225
  return Response(
226
  content=r.content,
227
+ media_type=r.headers.get("content-type", "application/json")
 
 
 
228
  )
229
 
230
  except Exception as e:
231
+ log(f"Key {key[:8]}... exception: {e}")
232
+ await mark_fail(key)
233
 
234
  finally:
235
+ await release_key(key)
236
 
237
  return JSONResponse({"error": "All keys failed"}, status_code=500)
238
 
 
243
  tried = set()
244
 
245
  for _ in range(len(OLLAMA_KEYS)):
246
+ key = await wait_for_free_key(exclude=tried)
247
 
248
  if not key:
249
+ break
 
250
 
251
  tried.add(key)
252
 
 
260
  ) as r:
261
 
262
  if r.status_code == 429:
263
+ log(f"Key {key[:8]}... rate limited (stream chat), trying next")
264
+ await mark_fail(key)
265
  continue
266
 
267
+ hit_limit_mid_stream = False
268
+
269
  async for line in r.aiter_lines():
270
+ if not line:
271
+ continue
272
+
273
+ # Heuristic: treat any line containing "429" or "usage limit" as a mid-stream rate-limit signal (may also match ordinary content)
274
+ if "429" in line or "usage limit" in line.lower():
275
+ log(f"Key {key[:8]}... mid-stream limit detected, aborting chunk")
276
+ hit_limit_mid_stream = True
277
+ break
278
 
279
+ yield line + "\n\n"
280
+
281
+ if hit_limit_mid_stream:
282
+ await mark_fail(key)
283
+ continue
284
+
285
+ await mark_ok(key)
286
  return
287
 
288
  except Exception as e:
289
+ log(f"Key {key[:8]}... stream exception: {e}")
290
+ await mark_fail(key)
291
 
292
  finally:
293
+ await release_key(key)
294
 
295
  yield sse({"error": "All keys failed"})
296
  yield "data: [DONE]\n\n"
 
310
  body = await req.json()
311
  except ClientDisconnect:
312
  return Response(status_code=499)
313
+ except Exception:
314
+ return JSONResponse({"error": "Bad JSON"}, status_code=400)
315
 
316
  stream = body.get("stream", False)
317
 
318
+ # Build messages list for proxy
319
  messages = []
320
 
321
  if body.get("system"):
322
+ messages.append({"role": "system", "content": body["system"]})
 
 
 
323
 
324
  for m in body.get("messages", []):
325
  content = m.get("content", "")
 
331
  txt += x.get("text", "")
332
  content = txt
333
 
334
+ messages.append({"role": m["role"], "content": content})
 
 
 
335
 
336
  proxy_body = {
337
  "model": "minimax-m2.7:cloud",
 
346
  tried = set()
347
 
348
  for _ in range(len(OLLAMA_KEYS)):
349
+ key = await wait_for_free_key(exclude=tried)
350
 
351
  if not key:
352
+ break
 
353
 
354
  tried.add(key)
355
 
 
364
  txt = r.text.lower()
365
 
366
  if "weekly usage limit" in txt or r.status_code == 429:
367
+ log(f"Key {key[:8]}... rate limited (non-stream anthropic), trying next")
368
+ await mark_fail(key)
369
  continue
370
 
371
  data = r.json()
 
372
  ans = data["choices"][0]["message"]["content"]
373
 
374
  out = {
 
376
  "type": "message",
377
  "role": "assistant",
378
  "model": body.get("model", "claude-opus-4-7"),
379
+ "content": [{"type": "text", "text": ans}],
 
 
 
 
 
380
  "stop_reason": "end_turn",
381
  "stop_sequence": None,
382
+ "usage": {"input_tokens": 0, "output_tokens": 0}
 
 
 
383
  }
384
 
385
+ await mark_ok(key)
386
  return JSONResponse(out)
387
 
388
  except Exception as e:
389
+ log(f"Key {key[:8]}... exception: {e}")
390
+ await mark_fail(key)
391
 
392
  finally:
393
+ await release_key(key)
394
 
395
  return JSONResponse({"error": "All keys failed"}, status_code=500)
396
 
397
  # -----------------------------------------
398
+ # STREAM (Anthropic SSE format)
399
  # -----------------------------------------
400
  async def agen():
401
  tried = set()
402
  msg_id = "msg_" + uuid.uuid4().hex[:10]
403
+ sent_any_delta = False
404
 
405
+ # The Anthropic envelope (message_start / content_block_start) is not sent up front.
406
+ # It is emitted inside the retry loop, after a key returns a non-429 response,
407
+ # and only while sent_any_delta is False (no buffering is involved).
408
+ # NOTE(review): sent_any_delta is set only after a text delta, so a retry before any text re-emits the envelope.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
 
410
  for _ in range(len(OLLAMA_KEYS)):
411
+ key = await wait_for_free_key(exclude=tried)
412
 
413
  if not key:
414
+ break
 
415
 
416
  tried.add(key)
417
 
 
425
  ) as r:
426
 
427
  if r.status_code == 429:
428
+ log(f"Key {key[:8]}... rate limited (stream anthropic), trying next")
429
+ await mark_fail(key)
430
  continue
431
 
432
+ # Only emit Anthropic envelope on first successful key
433
+ if not sent_any_delta:
434
+ yield sse({
435
+ "type": "message_start",
436
+ "message": {
437
+ "id": msg_id,
438
+ "type": "message",
439
+ "role": "assistant",
440
+ "model": body.get("model", "claude-opus-4-7"),
441
+ "content": [],
442
+ "stop_reason": None,
443
+ "stop_sequence": None,
444
+ "usage": {"input_tokens": 0, "output_tokens": 0}
445
+ }
446
+ })
447
+ yield sse({
448
+ "type": "content_block_start",
449
+ "index": 0,
450
+ "content_block": {"type": "text"}
451
+ })
452
+
453
+ hit_limit_mid_stream = False
454
+
455
  async for line in r.aiter_lines():
456
  if not line.startswith("data: "):
457
  continue
 
461
  if raw == "[DONE]":
462
  break
463
 
464
+ # Heuristic: treat any payload containing "429" or "usage limit" as a mid-stream rate-limit signal (may also match ordinary content)
465
+ if "429" in raw or "usage limit" in raw.lower():
466
+ log(f"Key {key[:8]}... mid-stream limit in anthropic, aborting chunk")
467
+ hit_limit_mid_stream = True
468
+ break
469
+
470
  try:
471
  j = json.loads(raw)
472
+ except Exception:
473
  continue
474
 
475
  delta = j["choices"][0]["delta"]
476
  txt = delta.get("content", "")
477
 
478
  if txt:
479
+ sent_any_delta = True
480
  yield sse({
481
  "type": "content_block_delta",
482
  "index": 0,
483
+ "delta": {"type": "text_delta", "text": txt}
 
 
 
484
  })
485
 
486
+ if hit_limit_mid_stream:
487
+ await mark_fail(key)
488
+ # Continue to the next key; later deltas are appended to the same client stream.
489
+ # NOTE(review): the retry presumably re-sends the original request, so output may restart rather than resume; verify
490
+ continue
491
+
492
+ await mark_ok(key)
493
+ break # Success — exit key retry loop
494
 
495
  except Exception as e:
496
+ log(f"Key {key[:8]}... agen exception: {e}")
497
+ await mark_fail(key)
498
 
499
  finally:
500
+ await release_key(key)
 
 
 
 
 
501
 
502
+ # Close Anthropic SSE envelope
503
+ yield sse({"type": "content_block_stop", "index": 0})
504
  yield sse({
505
  "type": "message_delta",
506
+ "delta": {"stop_reason": "end_turn", "stop_sequence": None},
507
+ "usage": {"output_tokens": 0}
 
 
 
 
 
 
 
 
 
508
  })
509
+ yield sse({"type": "message_stop"})
510
 
511
+ return StreamingResponse(agen(), media_type="text/event-stream")