import os
import re
import traceback

import litellm
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import ChatCompletionInputMessage, ChatCompletionInputTool
from litellm.router import Router
from pydantic import BaseModel
from typing import List, Optional

# NOTE: disables TLS certificate verification for every litellm request
# (handy behind an intercepting proxy, but insecure in production).
litellm.ssl_verify = False


load_dotenv()


app = FastAPI()

# Collect every Groq API key exposed as an environment variable named
# GROQ_0, GROQ_1, ...; sorting the names keeps each key's index stable
# across restarts.
api_keys = [
    os.environ[name]
    for name in sorted(n for n in os.environ if re.match(r"^GROQ_\d+$", n))
]

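# Per-model rate limits as recorded here: rpm = requests/minute,
# rpd = requests/day, tpm = tokens/minute, tpd = tokens/day
# (None means no recorded cap).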
models_data = {
    "allam-2-7b": {"rpm": 30, "rpd": 7000, "tpm": 6000},
    "compound-beta": {"rpm": 15, "rpd": 200, "tpm": 70000},
    "compound-beta-mini": {"rpm": 15, "rpd": 200, "tpm": 70000},
    "deepseek-r1-distill-llama-70b": {"rpm": 30, "rpd": 1000, "tpm": 6000},
    "gemma2-9b-it": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": 500000},
    "llama-3.1-8b-instant": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "llama-3.3-70b-versatile": {"rpm": 30, "rpd": 1000, "tpm": 12000, "tpd": 100000},
    "llama3-70b-8192": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "llama3-8b-8192": {"rpm": 30, "rpd": 14400, "tpm": 6000, "tpd": 500000},
    "meta-llama/llama-4-maverick-17b-128e-instruct": {"rpm": 30, "rpd": 1000, "tpm": 6000, "tpd": None},
    "meta-llama/llama-4-scout-17b-16e-instruct": {"rpm": 30, "rpd": 1000, "tpm": 30000, "tpd": None},
    "meta-llama/llama-guard-4-12b": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": 500000},
    "meta-llama/llama-prompt-guard-2-22m": {"rpm": 30, "rpd": 14400, "tpm": 15000, "tpd": None},
    "meta-llama/llama-prompt-guard-2-86m": {"rpm": 30, "rpd": 14400, "tpm": None, "tpd": None},
}

model_list = [
    {
        # The first key keeps the bare model name; extra keys get a "_<idx>" suffix.
        "model_name": model_name if key_idx == 0 else f"{model_name}_{key_idx}",
        "litellm_params": {
            "model": f"groq/{model_name}",
            "api_key": api_key,
            # Per litellm's Router docs, timeout and max_retries are read from
            # litellm_params, not from the top level of the deployment dict.
            "timeout": 120,
            "max_retries": 5,
        },
    }
    for model_name in models_data
    for key_idx, api_key in enumerate(api_keys)
]

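# Illustration (assuming two keys, GROQ_0 and GROQ_1): gemma2-9b-it produces
# two deployments, "gemma2-9b-it" and "gemma2-9b-it_1", both routing to
# "groq/gemma2-9b-it" upstream with their respective keys.
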
def generate_fallbacks_per_key():
    """For every (model, key) deployment, fall back to the same model on each
    of the other API keys. The compound models are excluded from fallbacks."""
    fallbacks = []
    excluded_models = {"compound-beta", "compound-beta-mini"}

    for model_name in models_data:
        if model_name in excluded_models:
            continue

        for key_idx in range(len(api_keys)):
            current_model = model_name if key_idx == 0 else f"{model_name}_{key_idx}"
            fallback_versions = [
                model_name if other_idx == 0 else f"{model_name}_{other_idx}"
                for other_idx in range(len(api_keys))
                if other_idx != key_idx
            ]

            fallbacks.append({current_model: fallback_versions})

    return fallbacks


fallbacks = generate_fallbacks_per_key()

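# Illustration (assuming two keys): the list contains entries such as
#   {"gemma2-9b-it": ["gemma2-9b-it_1"]} and {"gemma2-9b-it_1": ["gemma2-9b-it"]},
# so a failing or rate-limited key hands the same model off to the other keys.
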
router = Router(
    model_list=model_list,
    fallbacks=fallbacks,
    num_retries=5,
    retry_after=10,
)

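# Quick manual sanity check (hypothetical; not executed at import time):
#   resp = router.completion(
#       model="llama3-8b-8192",
#       messages=[{"role": "user", "content": "ping"}],
#   )
#   print(resp.choices[0].message.content)
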
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["GET", "POST"],
    allow_origins=["*"],
)

class ChatRequest(BaseModel):
    """Payload for /chat. `models` are tried in order until one succeeds.
    `stream` is accepted but streaming is not implemented by this endpoint."""
    models: List[str]
    messages: List[ChatCompletionInputMessage]
    tools: Optional[List[ChatCompletionInputTool]] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stop: Optional[List[str]] = None

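# Example request body (illustrative values only):
# {
#   "models": ["llama-3.1-8b-instant", "llama3-8b-8192"],
#   "messages": [{"role": "user", "content": "Hello"}],
#   "temperature": 0.7,
#   "max_tokens": 256
# }
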
def clean_message(msg) -> dict:
    """Convert a message into a plain dict, handling the different object
    types a message may arrive as and dropping None-valued fields."""
    if hasattr(msg, "model_dump"):
        # Pydantic v2 model
        return {k: v for k, v in msg.model_dump().items() if v is not None}
    elif hasattr(msg, "__dict__"):
        # Plain object (e.g. the dataclass-based huggingface_hub types)
        return {k: v for k, v in msg.__dict__.items() if v is not None}
    elif isinstance(msg, dict):
        # Already a dict
        return {k: v for k, v in msg.items() if v is not None}
    else:
        # Last resort: anything dict() can consume
        return dict(msg)

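# Example:
#   clean_message({"role": "user", "content": "Hi", "name": None})
#   -> {"role": "user", "content": "Hi"}
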
@app.get("/")
def main_page():
    return {"status": "ok"}

@app.post("/chat")
def chat_with_groq(req: ChatRequest):
    models = req.models
    if len(models) == 1 and (models[0] == "" or models[0] not in models_data):
        raise HTTPException(400, detail="Empty or unknown model field")
    messages = [clean_message(m) for m in req.messages]
    # Forward only the optional generation parameters the client actually set.
    extra_params = req.model_dump(
        exclude={"models", "messages"}, exclude_defaults=True, exclude_none=True
    )
    if len(models) == 1:
        try:
            resp = router.completion(model=models[0], messages=messages, **extra_params)
            print("Asked to", models[0], ":", messages)
            return {"error": False, "content": resp.choices[0].message.content}
        except Exception as e:
            traceback.print_exception(e)
            return {"error": True, "content": "No key works with the selected model, please wait..."}
    else:
        for model in models:
            if model not in models_data:
                print(f"Error: {model} does not exist")
                continue
            try:
                resp = router.completion(model=model, messages=messages, **extra_params)
                print("Asked to", model, ":", messages)
                return {"error": False, "content": resp.choices[0].message.content}
            except Exception as e:
                traceback.print_exception(e)
                continue
        return {"error": True, "content": "None of the models worked with any of the keys, please wait..."}
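

# Example invocation (hypothetical, server running locally on port 8000):
#   curl -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"models": ["llama3-8b-8192"], "messages": [{"role": "user", "content": "Hi"}]}'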