Spaces:
Sleeping
Sleeping
| """ | |
Cerebras Proxy Server
- OpenAI-compatible endpoint: /v1/chat/completions
- Anthropic-compatible endpoint: /v1/messages
- Token limiting: automatically truncates the oldest messages
- Multi-key round-robin; streaming requests keep retrying across keys instead of giving up
- Tool calling support (Anthropic <-> OpenAI conversion)
| """ | |
| import os | |
| import json | |
| import time | |
| import uuid | |
| import asyncio | |
| import httpx | |
| import tiktoken | |
| from fastapi import FastAPI, Request | |
| from fastapi.responses import JSONResponse, Response, StreamingResponse | |
| from starlette.requests import ClientDisconnect | |
app = FastAPI()
# =====================================================
# CONFIG
# =====================================================
# NOTE(review): without the MASTER_API_KEY env var the proxy accepts a fixed,
# guessable key; set it explicitly in any real deployment.
MASTER_API_KEY = os.getenv("MASTER_API_KEY", "olla")
CEREBRAS_BASE_URL = os.getenv("CEREBRAS_BASE_URL", "https://api.cerebras.ai/v1")
# Token budget handed to truncate_messages() for every incoming request.
MAX_REQUEST_TOKENS = int(os.getenv("MAX_REQUEST_TOKENS", "20000"))
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "llama-4-scout-17b-16e-instruct")

# Built-in aliases: every well-known Anthropic / OpenAI model name is routed to
# the same Cerebras model. Entries can be overridden via MODEL_MAP
# (see load_model_mapping).
DEFAULT_MODEL_MAPPING = {
    alias: "llama-4-scout-17b-16e-instruct"
    for alias in (
        "claude-opus-4-7",
        "claude-opus-4-6",
        "claude-opus-4-5",
        "claude-opus-4-1",
        "claude-opus-4-20250514",
        "claude-sonnet-4-6",
        "claude-sonnet-4-5",
        "claude-sonnet-4-20250514",
        "claude-haiku-4-5",
        "claude-haiku-4-5-20251001",
        "gpt-4",
        "gpt-4o",
        "gpt-4o-mini",
        "gpt-4-turbo",
        "gpt-3.5-turbo",
    )
}
def load_model_mapping():
    """Return the alias -> Cerebras model table.

    Starts from a copy of DEFAULT_MODEL_MAPPING and applies overrides from the
    MODEL_MAP environment variable, given as comma-separated ``alias:model``
    pairs, e.g. ``MODEL_MAP="gpt-4:llama3.3-70b, claude-sonnet-4-5:qwen-3-32b"``.
    Only the first ``:`` splits a pair, so target ids may contain ``:``.
    Whitespace around alias and target is stripped.

    Entries without a ``:`` or with an empty alias or empty target (for example
    a trailing ``"gpt-4:"``) are ignored. Otherwise requests for that alias
    would be sent upstream with a blank model name.
    """
    mapping = DEFAULT_MODEL_MAPPING.copy()
    env_map = os.getenv("MODEL_MAP")
    if not env_map:
        return mapping
    for pair in env_map.split(","):
        alias, sep, target = pair.partition(":")
        alias, target = alias.strip(), target.strip()
        if sep and alias and target:
            mapping[alias] = target
    return mapping
def map_model(model_name: str) -> str:
    """Translate a client-facing model name into the Cerebras model id.

    Names without an entry in the mapping are passed through unchanged. The
    mapping is rebuilt on every call, so MODEL_MAP is re-read from the
    environment each time.
    """
    table = load_model_mapping()
    return table.get(model_name, model_name)
| # ===================================================== | |
| # API KEYS | |
| # ===================================================== | |
# Upstream keys come from CEREBRAS_KEY_1 .. CEREBRAS_KEY_100, in numeric order;
# unset or empty variables are skipped.
API_KEYS = [
    value
    for value in (os.getenv(f"CEREBRAS_KEY_{n}") for n in range(1, 101))
    if value
]
if not API_KEYS:
    # Single-key fallback. "dummy_key" keeps the list non-empty, because the
    # round-robin code takes a modulo by len(API_KEYS). Requests made with it
    # will presumably be rejected upstream.
    API_KEYS.append(os.getenv("CEREBRAS_API_KEY") or "dummy_key")
| # ===================================================== | |
| # KEY STATUS & ROUND ROBIN | |
| # ===================================================== | |
# Seconds a key is benched after the upstream reports a rate limit.
RATE_LIMIT_COOLDOWN = int(os.getenv("RATE_LIMIT_COOLDOWN", "62"))

# Per-key bookkeeping, keyed by the raw API key.
key_status = {
    api_key: {
        "index": position,  # 1-based position, used in log lines
        "prefix": (api_key[:8] + "...") if len(api_key) > 8 else api_key,  # masked label
        "busy": False,  # True while a request holds this key
        "success": 0,
        "fail": 0,
        "rate_limited_until": 0.0,  # epoch seconds; 0 = not cooling down
    }
    for position, api_key in enumerate(API_KEYS, start=1)
}

# Round-robin cursor into API_KEYS, advanced by _get_available_key().
rr_index = 0
# Serialises access to key_status and rr_index across concurrent requests.
_key_lock = asyncio.Lock()
| # ===================================================== | |
| # TOKEN COUNTING | |
| # ===================================================== | |
| try: | |
| _encoder = tiktoken.get_encoding("cl100k_base") | |
| except Exception: | |
| _encoder = None | |
| def count_tokens(text: str) -> int: | |
| if _encoder is None: | |
| return len(text) | |
| return len(_encoder.encode(text, disallowed_special=())) | |
def count_messages_tokens(messages: list) -> int:
    """Estimate the prompt size of an OpenAI-style message list.

    Per message it counts:
    - string ``content``;
    - ``text`` parts of list content, plus a flat 1500 per ``image_url`` part
      (non-dict parts are ignored);
    - the function name and arguments of every assistant ``tool_calls`` entry;
    - a fixed overhead of 4 tokens.

    Assistant tool-call messages usually carry ``content=None``, so their
    arguments must be counted explicitly. Otherwise long tool-call histories
    are under-estimated and slip past truncation.
    """
    total = 0
    for msg in messages:
        content = msg.get("content", "")
        if isinstance(content, str):
            total += count_tokens(content)
        elif isinstance(content, list):
            for part in content:
                if not isinstance(part, dict):
                    continue
                part_type = part.get("type")
                if part_type == "text":
                    # `or ""` guards against an explicit "text": null
                    total += count_tokens(part.get("text") or "")
                elif part_type == "image_url":
                    total += 1500
        for call in msg.get("tool_calls") or ():
            if not isinstance(call, dict):
                continue
            fn = call.get("function") or {}
            total += count_tokens(str(fn.get("name") or ""))
            args = fn.get("arguments") or ""
            total += count_tokens(args if isinstance(args, str) else json.dumps(args))
        total += 4
    return total
def _truncate_text(text: str, max_tokens: int) -> str:
    """Return the longest prefix of *text* that count_tokens() rates at most *max_tokens*.

    With the tiktoken encoder this is an exact token slice. Without it, a
    binary search over the prefix length keeps the cut consistent with whatever
    fallback count_tokens() uses. A separate chars-per-token guess could
    disagree with count_tokens() and leave the "truncated" text over budget.
    """
    if max_tokens <= 0:
        return ""
    if _encoder is not None:
        tokens = _encoder.encode(text, disallowed_special=())
        if len(tokens) <= max_tokens:
            return text
        return _encoder.decode(tokens[:max_tokens])
    if count_tokens(text) <= max_tokens:
        return text
    lo, hi = 0, len(text)
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if count_tokens(text[:mid]) <= max_tokens:
            lo = mid
        else:
            hi = mid - 1
    return text[:lo]


def truncate_messages(messages: list, max_tokens: int) -> list:
    """Trim an OpenAI-style message list to fit a token budget.

    The working budget is ``max(1000, max_tokens - 2000)``. Lists already within
    it are returned unchanged (same object).

    Otherwise all system messages and the final message are always kept.
    Middle messages are then added newest-first as one contiguous run. The
    first message that does not fit is kept in shortened form (string content
    only, when more than 50 tokens remain) with a "[...truncated]" marker, and
    everything older is dropped.

    ``tool`` messages left at the start of the kept history are also dropped:
    the assistant message carrying their ``tool_calls`` was cut, and
    OpenAI-compatible backends reject a tool result without it.

    If the system messages plus the final message alone exceed the budget:
    - the first system message (string content) is capped at
      ``min(2000, max_tokens // 4)`` tokens;
    - the final message (string content) is cut to
      ``max_tokens - <system tokens> - 10``;
    - only those messages are returned.

    NOTE(review): if the final message is itself a ``tool`` result whose
    assistant message gets dropped, it is still sent orphaned.
    """
    if not messages:
        return messages
    total = count_messages_tokens(messages)
    safety_limit = max(1000, max_tokens - 2000)
    if total <= safety_limit:
        return messages

    log(f"⚠️ Token count {total} exceeds safety limit {safety_limit}. Truncating...")
    system_msgs = [m for m in messages if m.get("role") == "system"]
    other_msgs = [m for m in messages if m.get("role") != "system"]
    if not other_msgs:
        return messages

    last_msg = other_msgs[-1]
    middle_msgs = other_msgs[:-1]
    remaining_budget = (
        safety_limit
        - count_messages_tokens(system_msgs)
        - count_messages_tokens([last_msg])
    )

    if remaining_budget < 0:
        # System prompt + latest message alone are over budget: shrink both.
        if system_msgs and isinstance(system_msgs[0].get("content", ""), str):
            max_sys = min(2000, max_tokens // 4)
            system_msgs[0] = {
                **system_msgs[0],
                "content": _truncate_text(system_msgs[0].get("content", ""), max_sys),
            }
        last_content = last_msg.get("content", "")
        if isinstance(last_content, str):
            max_last = max_tokens - count_messages_tokens(system_msgs) - 10
            if max_last > 0:
                last_msg = {**last_msg, "content": _truncate_text(last_content, max_last)}
        result = system_msgs + [last_msg]
        log(f"✂️ TRUNCATE: {total} -> {count_messages_tokens(result)} tokens")
        return result

    kept_middle = []  # built newest-first, reversed below
    for msg in reversed(middle_msgs):
        msg_tokens = count_messages_tokens([msg])
        if msg_tokens <= remaining_budget:
            kept_middle.append(msg)
            remaining_budget -= msg_tokens
            continue
        # First message that does not fit: keep a shortened copy when there is
        # meaningful room, then stop so the kept history has no gaps.
        content = msg.get("content", "")
        if remaining_budget > 50 and isinstance(content, str):
            truncated = _truncate_text(content, remaining_budget - 10)
            kept_middle.append({**msg, "content": truncated + "\n[...truncated]"})
        break
    kept_middle.reverse()

    # Drop leading tool results whose assistant tool_calls message was cut off.
    first = 0
    while first < len(kept_middle) and kept_middle[first].get("role") == "tool":
        first += 1
    kept_middle = kept_middle[first:]

    result = system_msgs + kept_middle + [last_msg]
    log(f"✂️ TRUNCATE: {total} -> {count_messages_tokens(result)} tokens")
    return result
| # ===================================================== | |
| # UTILITY | |
| # ===================================================== | |
def log(msg):
    """Print *msg* prefixed with the local time as ``[HH:MM:SS]``, flushing stdout immediately."""
    stamp = time.strftime("%H:%M:%S")
    print(f"[{stamp}] {msg}", flush=True)
def sse(obj):
    """Encode *obj* as one Server-Sent Events ``data:`` frame (JSON, non-ASCII kept verbatim)."""
    payload = json.dumps(obj, ensure_ascii=False)
    return f"data: {payload}\n\n"
def auth_ok(req: Request):
    """Return True when the request presents MASTER_API_KEY.

    Accepted credentials:
    - ``Authorization: Bearer <key>``: OpenAI-style; the scheme is matched
      case-insensitively and only a leading scheme is stripped;
    - ``Authorization: <key>``: a bare key, still accepted for compatibility;
    - ``x-api-key: <key>``: the header Anthropic SDK clients send to
      /v1/messages.

    Comparison is constant-time (on UTF-8 bytes, so non-ASCII header values
    cannot raise) to avoid leaking the key through response timing.
    """
    import hmac

    candidates = []
    auth = req.headers.get("Authorization", "")
    scheme, _, credentials = auth.partition(" ")
    candidates.append(credentials.strip() if scheme.lower() == "bearer" else auth)
    api_key_header = req.headers.get("x-api-key")
    if api_key_header is not None:
        candidates.append(api_key_header.strip())

    expected = MASTER_API_KEY.encode("utf-8")
    return any(hmac.compare_digest(c.encode("utf-8"), expected) for c in candidates)
def is_rate_limited_status(status_code: int) -> bool:
    """Return True when an upstream HTTP status means "429 Too Many Requests"."""
    too_many_requests = 429
    return status_code == too_many_requests
def is_rate_limited_error_body(text: str) -> bool:
    """Detect a rate-limit message in an upstream HTTP/API *error* body.

    Case-insensitive substring match on "rate limit", "too many requests" and
    "usage limit". Do NOT use this on model output, which may legitimately
    contain these phrases.
    """
    lowered = text.lower()
    markers = ("rate limit", "too many requests", "usage limit")
    return any(marker in lowered for marker in markers)
| # ===================================================== | |
| # KEY MANAGEMENT | |
| # ===================================================== | |
def _get_available_key(exclude: set) -> str | None:
    """Claim the next usable key in round-robin order, or return None.

    Synchronous; must be called while holding _key_lock. Starting just after
    the current rr_index, each slot of API_KEYS is inspected once. The first
    key that is not busy, not in rate-limit cooldown and not in *exclude* is
    marked busy and returned; the caller must release it later. rr_index is
    left on the last slot inspected.
    """
    global rr_index
    now = time.time()
    key_count = len(API_KEYS)
    for _attempt in range(key_count):
        rr_index = (rr_index + 1) % key_count
        candidate = API_KEYS[rr_index]
        state = key_status[candidate]
        if state["busy"] or now < state["rate_limited_until"] or candidate in exclude:
            continue
        state["busy"] = True
        return candidate
    return None
def _next_available_time() -> float:
    """Return the epoch time when the earliest cooling-down key becomes usable.

    Returns the current time when no key is in cooldown.
    """
    now = time.time()
    pending = [
        state["rate_limited_until"]
        for state in key_status.values()
        if state["rate_limited_until"] > now
    ]
    return min(pending, default=now)
async def get_key(exclude=None):
    """Claim a free key under _key_lock (see _get_available_key); None if none is available now."""
    skip = exclude if exclude is not None else set()
    async with _key_lock:
        return _get_available_key(skip)
async def release_key(key):
    """Clear the busy flag of *key*; unknown keys are ignored."""
    async with _key_lock:
        state = key_status.get(key)
        if state is not None:
            state["busy"] = False
async def mark_rate_limited(key):
    """Bench *key* for RATE_LIMIT_COOLDOWN seconds and count one failure.

    Unknown keys are ignored. Logs the time at which the key becomes usable again.
    """
    async with _key_lock:
        state = key_status.get(key)
        if state is None:
            return
        until = time.time() + RATE_LIMIT_COOLDOWN
        state["rate_limited_until"] = until
        state["fail"] += 1
        resume_at = time.strftime('%H:%M:%S', time.localtime(until))
        log(f"⏳ key#{state['index']} cooldown {RATE_LIMIT_COOLDOWN}s (sampai {resume_at})")
async def mark_fail(key):
    """Increment the failure counter of *key*; unknown keys are ignored."""
    async with _key_lock:
        state = key_status.get(key)
        if state is not None:
            state["fail"] += 1
async def mark_ok(key):
    """Record a success for *key*.

    Bumps its success count, resets its failure count and clears any cooldown.
    Unknown keys are ignored.
    """
    async with _key_lock:
        state = key_status.get(key)
        if state is None:
            return
        state.update(success=state["success"] + 1, fail=0, rate_limited_until=0.0)
async def wait_for_free_key(exclude=None, max_wait=120.0, interval=0.5):
    """Poll for a claimable key every *interval* seconds.

    Returns the claimed key, or None once the summed sleep time reaches
    *max_wait*. Time spent waiting for the lock is not counted.
    """
    slept = 0.0
    while slept < max_wait:
        claimed = await get_key(exclude)
        if claimed:
            return claimed
        await asyncio.sleep(interval)
        slept += interval
    return None
async def get_key_infinite(exclude=None):
    """Wait, with no time limit, until a key can be claimed.

    Returns ``(key, exclude_set)``. The key is already marked busy and the
    caller must release it. The caller's *exclude* is copied, never mutated.

    Each round:
    - a free, non-excluded key is claimed and returned immediately;
    - if the only idle, out-of-cooldown keys are excluded ones, the exclude set
      is cleared and the search repeated at once. Such keys never become busy
      or cooling down, so polling for them would otherwise spin forever;
    - if every key is busy or cooling down, sleep until the earliest cooldown
      ends (at least 0.5 s) and clear the exclude set so every key is eligible
      again;
    - otherwise (a cooldown expired between the scan and this check), poll
      again after 0.3 s.
    """
    local_exclude = set(exclude) if exclude else set()
    cycle = 0
    while True:
        async with _key_lock:
            key = _get_available_key(local_exclude)
            if key:
                return key, local_exclude
            now = time.time()
            wait_sec = max(0.5, _next_available_time() - now)
            # Keys that could serve a request right now, ignoring the exclude set.
            usable = [
                k for k, st in key_status.items()
                if not st["busy"] and st["rate_limited_until"] <= now
            ]
        if usable and all(k in local_exclude for k in usable):
            local_exclude.clear()
            continue
        if not usable:
            cycle += 1
            log(f"⏳ Semua key cooldown. Tunggu {wait_sec:.1f}s sampai key berikutnya ready... (cycle #{cycle})")
            local_exclude.clear()  # every key becomes eligible again after the wait
            await asyncio.sleep(wait_sec)
        else:
            await asyncio.sleep(0.3)
| # ===================================================== | |
| # TOOL CONVERSION: Anthropic ↔ OpenAI | |
| # ===================================================== | |
def anthropic_tools_to_openai(anthropic_tools: list) -> list:
    """Convert Anthropic tool definitions into OpenAI ``tools`` entries.

    Anthropic: {"name": "fn", "description": "...", "input_schema": {...}}
    OpenAI:    {"type": "function",
                "function": {"name": "fn", "description": "...", "parameters": {...}}}

    A missing name/description becomes "" and a missing schema becomes an
    empty object schema.
    """
    def to_function(tool):
        return {
            "name": tool.get("name", ""),
            "description": tool.get("description", ""),
            "parameters": tool.get("input_schema", {"type": "object", "properties": {}}),
        }

    return [{"type": "function", "function": to_function(tool)} for tool in anthropic_tools]
| def anthropic_tool_choice_to_openai(tool_choice) -> str | dict | None: | |
| """Convert Anthropic tool_choice → OpenAI tool_choice.""" | |
| if tool_choice is None: | |
| return None | |
| if isinstance(tool_choice, str): | |
| mapping = {"auto": "auto", "any": "required", "none": "none"} | |
| return mapping.get(tool_choice, "auto") | |
| if isinstance(tool_choice, dict): | |
| tc_type = tool_choice.get("type", "") | |
| if tc_type == "tool": | |
| return {"type": "function", "function": {"name": tool_choice.get("name", "")}} | |
| mapping = {"auto": "auto", "any": "required", "none": "none"} | |
| return mapping.get(tc_type, "auto") | |
| return "auto" | |
def convert_anthropic_messages_to_openai(anthropic_messages: list) -> list:
    """Convert Anthropic messages into OpenAI chat messages.

    - String content is passed through unchanged; non-list, non-string
      content is stringified.
    - An assistant message with ``tool_use`` blocks becomes one assistant
      message with ``tool_calls``. Its text blocks are joined into ``content``,
      or ``content`` is None when there is no text.
    - A user message with ``tool_result`` blocks becomes one ``tool`` message
      per result, followed by a user message holding any text blocks.
    - Any other message becomes one message whose content is its text blocks
      joined.

    Other block types (e.g. images) are dropped. Content entries that are not
    objects are ignored instead of crashing the request, and a tool_result
    with null content becomes an empty string.
    """
    def result_text(value):
        # tool_result content may be a string, a list of blocks, or missing/null.
        if value is None:
            return ""
        if isinstance(value, list):
            return "".join(
                (x.get("text") or "") if isinstance(x, dict) else str(x)
                for x in value
            )
        return str(value)

    openai_messages = []
    for m in anthropic_messages:
        role = m.get("role", "user")
        content = m.get("content", "")
        if isinstance(content, str):
            openai_messages.append({"role": role, "content": content})
            continue
        if not isinstance(content, list):
            openai_messages.append({"role": role, "content": str(content)})
            continue

        blocks = [b for b in content if isinstance(b, dict)]
        text = "".join(b.get("text") or "" for b in blocks if b.get("type") == "text")
        tool_uses = [b for b in blocks if b.get("type") == "tool_use"]
        tool_results = [b for b in blocks if b.get("type") == "tool_result"]

        if tool_uses and role == "assistant":
            openai_messages.append({
                "role": "assistant",
                "content": text or None,
                "tool_calls": [
                    {
                        "id": b.get("id", "call_" + uuid.uuid4().hex[:8]),
                        "type": "function",
                        "function": {
                            "name": b.get("name", ""),
                            "arguments": json.dumps(b.get("input", {})),
                        },
                    }
                    for b in tool_uses
                ],
            })
            continue

        if tool_results and role == "user":
            # OpenAI requires tool results right after the assistant tool_calls
            # message, so they go first and any accompanying text after them.
            for b in tool_results:
                openai_messages.append({
                    "role": "tool",
                    "tool_call_id": b.get("tool_use_id", ""),
                    "content": result_text(b.get("content")),
                })
            if text:
                openai_messages.append({"role": "user", "content": text})
            continue

        openai_messages.append({"role": role, "content": text})
    return openai_messages
def openai_response_to_anthropic(data: dict, original_model: str) -> dict:
    """Convert a non-streaming OpenAI chat completion into an Anthropic Messages response.

    Non-empty text content becomes a ``text`` block. Each ``tool_calls`` entry
    becomes a ``tool_use`` block whose ``input`` is the parsed JSON arguments:
    - empty or null arguments become ``{}``;
    - a dict is used as-is;
    - invalid or non-object JSON is kept verbatim under ``{"_raw": ...}``.

    ``stop_reason`` is "tool_use" whenever any tool call is present, matching
    the streaming path, so clients that key tool execution off stop_reason see
    the calls. Otherwise finish_reason is mapped: "length" -> "max_tokens",
    "tool_calls" -> "tool_use", anything else -> "end_turn".

    Raises KeyError/IndexError if *data* has no choices; callers treat that as
    a failed attempt.
    """
    def parse_arguments(raw):
        if isinstance(raw, dict):
            return raw
        if not raw:
            return {}
        try:
            parsed = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return {"_raw": raw}
        if parsed is None:
            return {}
        return parsed if isinstance(parsed, dict) else {"_raw": raw}

    choice = data["choices"][0]
    message = choice.get("message") or {}
    usage = data.get("usage") or {}

    content_blocks = []
    text_content = message.get("content") or ""
    if text_content:
        content_blocks.append({"type": "text", "text": text_content})

    for tc in message.get("tool_calls") or []:
        fn = tc.get("function") or {}
        content_blocks.append({
            "type": "tool_use",
            "id": tc.get("id") or "toolu_" + uuid.uuid4().hex[:10],
            "name": fn.get("name", ""),
            "input": parse_arguments(fn.get("arguments")),
        })

    if any(block["type"] == "tool_use" for block in content_blocks):
        stop_reason = "tool_use"
    else:
        stop_map = {
            "stop": "end_turn",
            "length": "max_tokens",
            "eos": "end_turn",
            "tool_calls": "tool_use",
        }
        stop_reason = stop_map.get(choice.get("finish_reason", "stop"), "end_turn")

    return {
        "id": "msg_" + uuid.uuid4().hex[:10],
        "type": "message",
        "role": "assistant",
        "model": original_model,
        "content": content_blocks,
        "stop_reason": stop_reason,
        "stop_sequence": None,
        "usage": {
            "input_tokens": usage.get("prompt_tokens", 0),
            "output_tokens": usage.get("completion_tokens", 0),
        },
    }
| # ===================================================== | |
| # ROOT / STATUS | |
| # ===================================================== | |
async def root():
    """Report proxy configuration plus a per-key status snapshot.

    Each key is listed under its masked prefix with status BUSY, COOLDOWN or
    IDLE (BUSY wins over COOLDOWN), the remaining cooldown in seconds (0 when
    none), and its success/fail counters.
    """
    async with _key_lock:
        now = time.time()
        keys_info = {}
        # NOTE(review): two keys with the same masked prefix overwrite each other here.
        for state in key_status.values():
            remaining = max(0, state["rate_limited_until"] - now)
            if state["busy"]:
                label = "BUSY"
            elif remaining > 0:
                label = "COOLDOWN"
            else:
                label = "IDLE"
            keys_info[state["prefix"]] = {
                "status": label,
                "cooldown_remaining_sec": round(remaining, 1) if remaining > 0 else 0,
                "success": state["success"],
                "fail": state["fail"],
            }
    return {
        "status": "ok",
        "backend": "cerebras",
        "base_url": CEREBRAS_BASE_URL,
        "default_model": DEFAULT_MODEL,
        "max_request_tokens": MAX_REQUEST_TOKENS,
        "rate_limit_cooldown_sec": RATE_LIMIT_COOLDOWN,
        "total_keys": len(API_KEYS),
        "keys": keys_info,
    }
| # ===================================================== | |
| # /v1/models | |
| # ===================================================== | |
async def list_models(req: Request):
    """OpenAI-style ``/v1/models`` listing.

    Proxies ``GET {CEREBRAS_BASE_URL}/models`` using the first configured key
    and returns the upstream JSON verbatim on HTTP 200. On any other status, or
    on an exception (which is logged), it returns a static list of known
    Cerebras model ids instead.
    """
    if not auth_ok(req):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)

    upstream_key = API_KEYS[0] if API_KEYS else ""
    try:
        async with httpx.AsyncClient(timeout=30) as client:
            upstream = await client.get(
                f"{CEREBRAS_BASE_URL}/models",
                headers={"Authorization": f"Bearer {upstream_key}"},
            )
        if upstream.status_code == 200:
            return Response(content=upstream.content, media_type="application/json")
    except Exception as e:
        log(f"[/v1/models] Error: {e}")

    created = int(time.time())
    fallback_ids = (
        "llama-4-scout-17b-16e-instruct",
        "llama-4-maverick-17b-128e-instruct",
        "llama3.3-70b",
        "llama3.1-8b",
        "qwen-3-32b",
        "deepseek-r1-distill-llama-70b",
    )
    return {
        "object": "list",
        "data": [
            {"id": model_id, "object": "model", "created": created, "owned_by": "cerebras"}
            for model_id in fallback_ids
        ],
    }
| # ===================================================== | |
| # /v1/chat/completions (OpenAI-compatible) | |
| # ===================================================== | |
async def chat(req: Request):
    """OpenAI-compatible ``/v1/chat/completions`` proxy to Cerebras.

    The request is authenticated, the model name is mapped, and messages are
    truncated to MAX_REQUEST_TOKENS. Whitelisted OpenAI parameters are
    forwarded, and ``max_completion_tokens`` defaults to 8192 when no limit is
    given.

    Non-streaming:
    - each key is tried at most once;
    - rate-limited keys are put in cooldown and other failures move on to the
      next key;
    - upstream request errors (400/404/413/422) are returned to the client
      unchanged, because every key would get the same answer;
    - responds 500 when all keys fail.

    Streaming:
    - keys are retried without limit; rate-limited keys wait out their
      cooldown and other failures back off for 1 s;
    - a request is only re-sent on another key while nothing has been
      forwarded yet. Once output has reached the client, a failure ends the
      stream with an error chunk, since replaying would duplicate output;
    - upstream request errors also end the stream with an error chunk.
    """
    if not auth_ok(req):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    try:
        body = await req.json()
    except ClientDisconnect:
        return Response(status_code=499)
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError (undecodable bytes).
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
    if not isinstance(body, dict):
        # Valid JSON but not an object (e.g. a list): body.get() below would crash.
        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)

    is_stream = body.get("stream", False)
    original_model = body.get("model", DEFAULT_MODEL)
    cerebras_model = map_model(original_model)
    messages = truncate_messages(body.get("messages", []), MAX_REQUEST_TOKENS)
    log(f"[OAI] Model: {cerebras_model}, Tokens: {count_messages_tokens(messages)}")

    cerebras_body = {
        "model": cerebras_model,
        "messages": messages,
        "stream": is_stream,
    }
    forward_params = [
        "max_tokens", "max_completion_tokens", "temperature", "top_p", "stop",
        "frequency_penalty", "presence_penalty", "tools", "tool_choice",
        "parallel_tool_calls", "response_format",
    ]
    for param in forward_params:
        if param in body:
            cerebras_body[param] = body[param]
    if "max_tokens" not in cerebras_body and "max_completion_tokens" not in cerebras_body:
        cerebras_body["max_completion_tokens"] = 8192

    url = f"{CEREBRAS_BASE_URL}/chat/completions"

    def upstream_headers(key):
        return {"Authorization": f"Bearer {key}", "Content-Type": "application/json"}

    # Statuses that (presumably) describe a problem with the request itself,
    # such as a malformed body, unknown model or oversized payload. Retrying on
    # another key cannot help, so they are surfaced instead of retried.
    request_errors = {400, 404, 413, 422}

    def stream_error(message, err_type):
        # OpenAI-style error chunk followed by the end-of-stream sentinel.
        return sse({"error": {"message": message, "type": err_type}}) + "data: [DONE]\n\n"

    # -----------------------------------------
    # NON STREAM
    # -----------------------------------------
    if not is_stream:
        tried = set()
        for _ in range(len(API_KEYS)):
            key = await wait_for_free_key(exclude=tried)
            if not key:
                break
            tried.add(key)
            ki = key_status[key]
            log(f"NON-STREAM: key#{ki['index']}")
            try:
                async with httpx.AsyncClient(timeout=180) as client:
                    r = await client.post(url, json=cerebras_body, headers=upstream_headers(key))
                if is_rate_limited_status(r.status_code) or (
                    r.status_code != 200 and is_rate_limited_error_body(r.text)
                ):
                    log(f"RATE LIMITED: key#{ki['index']}")
                    await mark_rate_limited(key)
                    continue
                if r.status_code in request_errors:
                    log(f"HTTP {r.status_code}: key#{ki['index']} - {r.text[:200]}")
                    return Response(
                        content=r.content,
                        status_code=r.status_code,
                        media_type=r.headers.get("content-type", "application/json"),
                    )
                if r.status_code != 200:
                    log(f"HTTP {r.status_code}: key#{ki['index']}")
                    await mark_fail(key)
                    continue
                await mark_ok(key)
                return Response(content=r.content, media_type="application/json")
            except Exception as e:
                log(f"Exception: key#{ki['index']} - {e}")
                await mark_fail(key)
            finally:
                await release_key(key)
        return JSONResponse({"error": "All keys failed"}, status_code=500)

    # -----------------------------------------
    # STREAM: retries without limit until output starts
    # -----------------------------------------
    async def stream_gen():
        forwarded = False  # True once any upstream chunk has reached the client
        backoff = 0.0      # pause before the next attempt after a non-rate-limit failure
        while True:
            if backoff:
                await asyncio.sleep(backoff)
                backoff = 0.0
            key, _ = await get_key_infinite()
            ki = key_status[key]
            log(f"STREAM: key#{ki['index']}")
            try:
                async with httpx.AsyncClient(timeout=None) as client:
                    async with client.stream(
                        "POST", url, json=cerebras_body, headers=upstream_headers(key)
                    ) as r:
                        if r.status_code != 200:
                            err_text = (await r.aread()).decode("utf-8", "replace")
                            if is_rate_limited_status(r.status_code) or is_rate_limited_error_body(err_text):
                                log(f"STREAM RATE LIMITED: key#{ki['index']}")
                                await mark_rate_limited(key)
                                continue
                            log(f"STREAM HTTP {r.status_code}: key#{ki['index']}")
                            if r.status_code in request_errors:
                                yield stream_error(
                                    f"Upstream HTTP {r.status_code}: {err_text[:500]}",
                                    "invalid_request_error",
                                )
                                return
                            await mark_fail(key)
                            backoff = 1.0
                            continue

                        hit_limit = False
                        async for line in r.aiter_lines():
                            if not line:
                                continue
                            if line.strip() == "data: [DONE]":
                                break
                            raw = line[6:] if line.startswith("data: ") else line
                            try:
                                chunk = json.loads(raw)
                            except ValueError:
                                chunk = None
                            if (
                                isinstance(chunk, dict)
                                and "error" in chunk
                                and "choices" not in chunk
                                and is_rate_limited_error_body(json.dumps(chunk))
                            ):
                                log(f"MID-STREAM LIMIT: key#{ki['index']}")
                                hit_limit = True
                                break
                            yield line + "\n\n"
                            forwarded = True

                        if hit_limit:
                            await mark_rate_limited(key)
                            if forwarded:
                                yield stream_error(
                                    "Upstream rate limit hit mid-stream; response is incomplete",
                                    "rate_limit_error",
                                )
                                return
                            continue

                        yield "data: [DONE]\n\n"
                        await mark_ok(key)
                        return  # success
            except Exception as e:
                log(f"STREAM EXCEPTION: key#{ki['index']} - {e}")
                await mark_fail(key)
                if forwarded:
                    yield stream_error(f"Upstream stream failed: {e}", "api_error")
                    return
                backoff = 1.0
            finally:
                await release_key(key)

    return StreamingResponse(stream_gen(), media_type="text/event-stream")
| # ===================================================== | |
| # /v1/messages (Anthropic-compatible) | |
| # FIXED: Full tool calling support | |
| # ===================================================== | |
async def anthropic_messages(req: Request):
    """Anthropic-compatible ``/v1/messages`` endpoint backed by Cerebras.

    The Anthropic request is converted to an OpenAI chat request: system
    prompt, messages (including tool_use / tool_result blocks), tools,
    tool_choice, temperature and top_p. It is then truncated to
    MAX_REQUEST_TOKENS, and ``max_completion_tokens`` is set to
    ``min(max_tokens, 8192)`` with max_tokens defaulting to 4096.

    Non-streaming:
    - each key is tried at most once and the OpenAI response is converted back
      to Anthropic format;
    - upstream request errors (400/404/413/422) are returned immediately as
      Anthropic error objects.

    Streaming:
    - OpenAI chunks are re-emitted as Anthropic SSE events. Each event carries
      an ``event: <type>`` line, because the Anthropic SDKs dispatch on that
      field;
    - keys are retried without limit while no content has been sent. After
      content has started, or on a request error or a non-rate-limit API error
      inside the stream, the stream ends with an Anthropic ``error`` event,
      since retrying would duplicate content.
    """
    def error_response(err_type, message, status_code):
        return JSONResponse(
            {"type": "error", "error": {"type": err_type, "message": message}},
            status_code=status_code,
        )

    def event(obj):
        # Anthropic SSE frame: an "event:" line naming the type, then the JSON data line.
        return f"event: {obj['type']}\n" + sse(obj)

    def error_event(err_type, message):
        return event({"type": "error", "error": {"type": err_type, "message": message}})

    if not auth_ok(req):
        return error_response("authentication_error", "Unauthorized", 401)
    try:
        body = await req.json()
    except ClientDisconnect:
        return Response(status_code=499)
    except Exception:
        return error_response("invalid_request_error", "Bad JSON", 400)
    if not isinstance(body, dict):
        return error_response("invalid_request_error", "Bad JSON", 400)

    is_stream = body.get("stream", False)
    original_model = body.get("model", DEFAULT_MODEL)
    cerebras_model = map_model(original_model)
    max_tokens = body.get("max_tokens")
    if max_tokens is None:
        max_tokens = 4096

    # OpenAI-format message list: the system prompt first, then the converted
    # conversation (tool_use -> tool_calls, tool_result -> role "tool").
    messages = []
    system = body.get("system")
    if system:
        if isinstance(system, list):
            system = "".join(
                x.get("text") or ""
                for x in system
                if isinstance(x, dict) and x.get("type") == "text"
            )
        messages.append({"role": "system", "content": system})
    messages.extend(convert_anthropic_messages_to_openai(body.get("messages", [])))
    messages = truncate_messages(messages, MAX_REQUEST_TOKENS)
    input_tokens = count_messages_tokens(messages)
    log(f"[ANT] Model: {cerebras_model}, Tokens: {input_tokens}")

    cerebras_body = {
        "model": cerebras_model,
        "messages": messages,
        "stream": is_stream,
        "max_completion_tokens": min(max_tokens, 8192),
    }
    for param in ("temperature", "top_p"):
        if param in body:
            cerebras_body[param] = body[param]
    if body.get("tools"):
        cerebras_body["tools"] = anthropic_tools_to_openai(body["tools"])
    if body.get("tool_choice"):
        cerebras_body["tool_choice"] = anthropic_tool_choice_to_openai(body["tool_choice"])

    url = f"{CEREBRAS_BASE_URL}/chat/completions"

    def upstream_headers(key):
        return {"Authorization": f"Bearer {key}", "Content-Type": "application/json"}

    # Upstream statuses that (presumably) describe a problem with the request
    # itself, mapped to Anthropic error types. Retrying on another key cannot help.
    request_errors = {
        400: "invalid_request_error",
        404: "not_found_error",
        413: "request_too_large",
        422: "invalid_request_error",
    }

    # -----------------------------------------
    # NON STREAM
    # -----------------------------------------
    if not is_stream:
        tried = set()
        for _ in range(len(API_KEYS)):
            key = await wait_for_free_key(exclude=tried)
            if not key:
                break
            tried.add(key)
            ki = key_status[key]
            log(f"ANTHROPIC NON-STREAM: key#{ki['index']}")
            try:
                async with httpx.AsyncClient(timeout=180) as client:
                    r = await client.post(url, json=cerebras_body, headers=upstream_headers(key))
                if is_rate_limited_status(r.status_code) or (
                    r.status_code != 200 and is_rate_limited_error_body(r.text)
                ):
                    log(f"RATE LIMITED: key#{ki['index']}")
                    await mark_rate_limited(key)
                    continue
                if r.status_code != 200:
                    log(f"HTTP {r.status_code}: key#{ki['index']} - {r.text[:200]}")
                    if r.status_code in request_errors:
                        return error_response(
                            request_errors[r.status_code],
                            r.text[:2000] or f"Upstream HTTP {r.status_code}",
                            r.status_code,
                        )
                    await mark_fail(key)
                    continue
                out = openai_response_to_anthropic(r.json(), original_model)
                await mark_ok(key)
                return JSONResponse(out)
            except Exception as e:
                log(f"Exception: key#{ki['index']} - {e}")
                await mark_fail(key)
            finally:
                await release_key(key)
        return error_response("api_error", "All keys failed", 500)

    # -----------------------------------------
    # STREAM: Anthropic SSE, retries without limit until content starts
    # -----------------------------------------
    async def anthropic_stream_gen():
        msg_id = "msg_" + uuid.uuid4().hex[:10]
        started = False   # message_start + text block 0 already sent
        emitted = False   # any content event (text delta / tool block) already sent
        backoff = 0.0     # pause before the next attempt after a non-rate-limit failure
        while True:
            if backoff:
                await asyncio.sleep(backoff)
                backoff = 0.0
            key, _ = await get_key_infinite()
            ki = key_status[key]
            log(f"ANTHROPIC STREAM: key#{ki['index']}")
            try:
                async with httpx.AsyncClient(timeout=None) as client:
                    async with client.stream(
                        "POST", url, json=cerebras_body, headers=upstream_headers(key)
                    ) as r:
                        if r.status_code != 200:
                            err_text = (await r.aread()).decode("utf-8", "replace")
                            if is_rate_limited_status(r.status_code) or is_rate_limited_error_body(err_text):
                                log(f"STREAM RATE LIMITED: key#{ki['index']}")
                                await mark_rate_limited(key)
                                continue
                            log(f"STREAM HTTP {r.status_code}: key#{ki['index']}")
                            if r.status_code in request_errors:
                                yield error_event(
                                    request_errors[r.status_code],
                                    err_text[:2000] or f"Upstream HTTP {r.status_code}",
                                )
                                return
                            await mark_fail(key)
                            backoff = 1.0
                            continue

                        # Anthropic envelope header, sent only once across retries.
                        if not started:
                            started = True
                            yield event({
                                "type": "message_start",
                                "message": {
                                    "id": msg_id,
                                    "type": "message",
                                    "role": "assistant",
                                    "model": original_model,
                                    "content": [],
                                    "stop_reason": None,
                                    "stop_sequence": None,
                                    # Local estimate (same figure as the [ANT] log line).
                                    "usage": {"input_tokens": input_tokens, "output_tokens": 0},
                                },
                            })
                            yield event({
                                "type": "content_block_start",
                                "index": 0,
                                "content_block": {"type": "text", "text": ""},
                            })

                        hit_limit = False
                        upstream_error = None
                        output_tokens = 0
                        finish_reason = None
                        # OpenAI tool_call index -> Anthropic content block index
                        # (block 0 is the text block).
                        tool_index_map = {}
                        next_block_index = 1

                        async for line in r.aiter_lines():
                            if not line:
                                continue
                            if line.strip() == "data: [DONE]":
                                break
                            raw = line[6:] if line.startswith("data: ") else line
                            try:
                                j = json.loads(raw)
                            except json.JSONDecodeError:
                                continue
                            if not isinstance(j, dict):
                                continue

                            # API error object (not model output).
                            if "error" in j and "choices" not in j:
                                err_str = json.dumps(j)
                                if is_rate_limited_error_body(err_str):
                                    log(f"MID-STREAM LIMIT: key#{ki['index']}")
                                    hit_limit = True
                                else:
                                    log(f"MID-STREAM API ERROR: {err_str[:200]}")
                                    upstream_error = err_str[:2000]
                                break

                            if j.get("usage"):
                                output_tokens = j["usage"].get("completion_tokens", output_tokens)
                            choices = j.get("choices") or []
                            if not choices:
                                continue
                            choice = choices[0]
                            delta = choice.get("delta") or {}
                            finish_reason = choice.get("finish_reason") or finish_reason

                            # ---- TEXT CONTENT ----
                            text_token = delta.get("content") or ""
                            if text_token:
                                emitted = True
                                yield event({
                                    "type": "content_block_delta",
                                    "index": 0,
                                    "delta": {"type": "text_delta", "text": text_token},
                                })

                            # ---- TOOL CALLS ----
                            for tc_delta in delta.get("tool_calls") or []:
                                tc_idx = tc_delta.get("index", 0)
                                fn_delta = tc_delta.get("function") or {}
                                # A new tool call is announced by its id and/or name.
                                if (tc_delta.get("id") or fn_delta.get("name")) and tc_idx not in tool_index_map:
                                    tool_index_map[tc_idx] = next_block_index
                                    next_block_index += 1
                                    emitted = True
                                    yield event({
                                        "type": "content_block_start",
                                        "index": tool_index_map[tc_idx],
                                        "content_block": {
                                            "type": "tool_use",
                                            "id": tc_delta.get("id") or "toolu_" + uuid.uuid4().hex[:10],
                                            "name": fn_delta.get("name") or "",
                                            "input": {},
                                        },
                                    })
                                # Argument fragments are streamed as input_json_delta.
                                args_chunk = fn_delta.get("arguments") or ""
                                if args_chunk and tc_idx in tool_index_map:
                                    yield event({
                                        "type": "content_block_delta",
                                        "index": tool_index_map[tc_idx],
                                        "delta": {"type": "input_json_delta", "partial_json": args_chunk},
                                    })

                        if hit_limit:
                            await mark_rate_limited(key)
                            if emitted:
                                yield error_event(
                                    "rate_limit_error",
                                    "Upstream rate limit hit mid-stream; response is incomplete",
                                )
                                return
                            continue

                        if upstream_error is not None:
                            await mark_fail(key)
                            yield error_event("api_error", upstream_error)
                            return

                        # Close the text block, then every tool_use block.
                        yield event({"type": "content_block_stop", "index": 0})
                        for block_idx in tool_index_map.values():
                            yield event({"type": "content_block_stop", "index": block_idx})

                        if finish_reason == "tool_calls" or tool_index_map:
                            stop_reason = "tool_use"
                        elif finish_reason == "length":
                            stop_reason = "max_tokens"
                        else:
                            stop_reason = "end_turn"
                        yield event({
                            "type": "message_delta",
                            "delta": {"stop_reason": stop_reason, "stop_sequence": None},
                            "usage": {"output_tokens": output_tokens},
                        })
                        yield event({"type": "message_stop"})
                        await mark_ok(key)
                        return  # success
            except Exception as e:
                log(f"STREAM EXCEPTION: key#{ki['index']} - {e}")
                await mark_fail(key)
                if emitted:
                    yield error_event("api_error", f"Upstream stream failed: {e}")
                    return
                backoff = 1.0
            finally:
                await release_key(key)

    return StreamingResponse(anthropic_stream_gen(), media_type="text/event-stream")