Spaces:
Sleeping
Sleeping
File size: 6,034 Bytes
"""
LLM client for the Secure Gateway with multi-provider fallback.
"""
import asyncio
import json
import os
import time

import requests
class LLMClient:
    """Client that queries hosted LLM providers with cascade fallback.

    Providers are registered in a fixed order (Gemini, Groq, OpenRouter), and
    only when the matching ``*_API_KEY`` environment variable is set.
    ``query_llm_cascade`` tries them in that order and returns the first
    non-empty answer.
    """

    # Human-readable provider names used in error messages.
    _LABELS = {"gemini": "Gemini", "groq": "Groq", "openrouter": "OpenRouter"}

    # Fallback for instances created without __init__; without a timeout
    # ``requests`` would wait on an unresponsive provider forever.
    request_timeout = 60.0

    def __init__(self, request_timeout: float = 60.0):
        """Read provider credentials and model names from the environment.

        Args:
            request_timeout: Seconds to wait for a provider's HTTP response
                before counting the attempt as failed.
        """
        self.request_timeout = request_timeout

        self.gemini_api_key = os.getenv("GEMINI_API_KEY")
        self.groq_api_key = os.getenv("GROQ_API_KEY")
        self.openrouter_api_key = os.getenv("OPENROUTER_API_KEY")

        self.gemini_model = os.getenv("GEMINI_MODEL", "gemini-2.0-flash-exp")
        self.groq_model = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
        self.openrouter_model = os.getenv(
            "OPENROUTER_MODEL", "google/gemini-2.0-flash-exp:free"
        )

        # List order is the fallback order used by query_llm_cascade.
        self.providers = []
        if self.gemini_api_key:
            self.providers.append(
                {"name": "gemini", "key": self.gemini_api_key, "model": self.gemini_model}
            )
        if self.groq_api_key:
            self.providers.append(
                {"name": "groq", "key": self.groq_api_key, "model": self.groq_model}
            )
        if self.openrouter_api_key:
            self.providers.append(
                {
                    "name": "openrouter",
                    "key": self.openrouter_api_key,
                    "model": self.openrouter_model,
                }
            )

    @staticmethod
    def _build_request(provider_name, api_key, model, prompt, max_tokens, temperature):
        """Return ``(url, headers, payload)`` for a provider, or ``None`` if unknown."""
        headers = {"Content-Type": "application/json"}

        if provider_name == "gemini":
            # The key travels in a header rather than the query string so it
            # cannot end up in URLs that get logged or echoed in exceptions.
            headers["x-goog-api-key"] = api_key
            url = (
                "https://generativelanguage.googleapis.com/v1beta/models/"
                f"{model}:generateContent"
            )
            # Gemini uses 'contents' and 'generationConfig' instead of the
            # OpenAI-style 'messages' / 'max_tokens' / 'temperature' keys.
            payload = {
                "contents": [{"parts": [{"text": prompt}]}],
                "generationConfig": {
                    "maxOutputTokens": max_tokens,
                    "temperature": temperature,
                },
            }
            return url, headers, payload

        chat_payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        if provider_name == "groq":
            headers["Authorization"] = f"Bearer {api_key}"
            return "https://api.groq.com/openai/v1/chat/completions", headers, chat_payload

        if provider_name == "openrouter":
            headers["Authorization"] = f"Bearer {api_key}"
            headers["HTTP-Referer"] = "http://localhost:8000"  # Replace with your app URL
            headers["X-Title"] = "Secure LLM Router PoC"
            return "https://openrouter.ai/api/v1/chat/completions", headers, chat_payload

        return None

    @staticmethod
    def _extract_text(provider_name, data):
        """Pull the generated text out of a decoded provider response.

        Returns ``(text, None)`` on success or ``(None, error_message)`` when
        the response has no usable text. Never raises on malformed input.
        """
        if provider_name == "gemini":
            try:
                parts = data["candidates"][0]["content"]["parts"]
            except (KeyError, IndexError, TypeError):
                parts = None
            if isinstance(parts, list):
                for part in parts:
                    if isinstance(part, dict) and part.get("text"):
                        return part["text"], None
            return None, "No text content found in Gemini response."

        # Groq and OpenRouter both return the OpenAI chat-completions shape.
        try:
            text = data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            text = None
        if text:
            return text, None
        label = LLMClient._LABELS.get(provider_name, provider_name)
        return None, f"No content found in {label} response."

    def _send(self, provider_name, url, headers, payload):
        """Perform the blocking HTTP round trip and return ``(text, error)``.

        Network errors, HTTP error statuses, timeouts and undecodable bodies
        are reported through the error string instead of being raised, so the
        cascade can move on to the next provider.
        """
        label = self._LABELS.get(provider_name, provider_name)
        try:
            response = requests.post(
                url, headers=headers, json=payload, timeout=self.request_timeout
            )
            response.raise_for_status()
        except requests.exceptions.RequestException:
            return None, f"{label} API request failed"

        try:
            data = response.json()
        except ValueError:
            return None, f"{label} API returned invalid JSON"

        return self._extract_text(provider_name, data)

    async def call_llm_provider(
        self,
        provider_name: str,
        api_key: str,
        model: str,
        prompt: str,
        max_tokens: int,
        temperature: float,
    ):
        """Call a specific LLM provider.

        Args:
            provider_name: One of ``"gemini"``, ``"groq"`` or ``"openrouter"``.
            api_key: Credential for that provider.
            model: Provider-specific model identifier.
            prompt: User prompt sent as a single message.
            max_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature.

        Returns:
            ``(text, None)`` on success, ``(None, error_message)`` otherwise.
        """
        request = self._build_request(
            provider_name, api_key, model, prompt, max_tokens, temperature
        )
        if request is None:
            return None, f"Unknown LLM provider: {provider_name}"

        url, headers, payload = request
        # requests is synchronous; run it in the default executor so the
        # event loop stays responsive while waiting on the provider.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._send, provider_name, url, headers, payload
        )

    async def query_llm_cascade(self, prompt: str, max_tokens: int, temperature: float):
        """Query LLM with cascade fallback across providers.

        Providers are tried in ``self.providers`` order; the first one that
        returns non-empty text wins.

        Returns:
            ``(response, provider_name, latency_ms, error, cascade_path)``.
            ``cascade_path`` holds one dict per attempted provider with keys
            ``provider``, ``model``, ``status``, ``reason`` and ``latency_ms``.
            When every provider fails the result is
            ``(None, None, 0, "All LLM providers failed.", cascade_path)``.
        """
        cascade_path = []
        for provider in self.providers:
            provider_name = provider["name"]
            started = time.perf_counter()
            response_content, error = await self.call_llm_provider(
                provider_name=provider_name,
                api_key=provider["key"],
                model=provider["model"],
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            latency_ms = int((time.perf_counter() - started) * 1000)

            succeeded = bool(response_content)
            cascade_path.append(
                {
                    "provider": provider_name,
                    "model": provider["model"],
                    "status": "success" if succeeded else "failed",
                    "reason": None if succeeded else error,
                    "latency_ms": latency_ms,
                }
            )
            if succeeded:
                return response_content, provider_name, latency_ms, None, cascade_path

        return None, None, 0, "All LLM providers failed.", cascade_path
# Shared module-level client instance (can be imported and used in app.py).
# Construction reads the provider API keys and model names from the
# environment, so those variables must be set before this module is imported.
llm_client = LLMClient()