muthuk1 committed
Commit 2cd4ca5 · verified
1 Parent(s): ddb116f

Add Python universal LLM layer with LiteLLM supporting 12 providers + Ollama

Files changed (1)
  1. graphrag/layers/universal_llm.py +321 -0
graphrag/layers/universal_llm.py ADDED
@@ -0,0 +1,321 @@
"""
Universal LLM Layer - LiteLLM-powered multi-provider support
=============================================================
Supports 12 providers through a single interface:
OpenAI, Anthropic, Gemini, Mistral, Cohere, Ollama (local),
OpenRouter, Groq, xAI, Together AI, HuggingFace, DeepSeek

Uses LiteLLM for a unified API; falls back to the direct OpenAI SDK
if LiteLLM is not installed.
"""
import json
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)

# ── Provider Registry ─────────────────────────────────

PROVIDERS = {
    "openai": {
        "name": "OpenAI",
        "litellm_prefix": "openai",
        "default_model": "gpt-4o-mini",
        "api_key_env": "OPENAI_API_KEY",
        "cost_input": 0.00015, "cost_output": 0.0006,
    },
    "anthropic": {
        "name": "Anthropic Claude",
        "litellm_prefix": "anthropic",
        "default_model": "claude-sonnet-4-20250514",
        "api_key_env": "ANTHROPIC_API_KEY",
        "cost_input": 0.003, "cost_output": 0.015,
    },
    "gemini": {
        "name": "Google Gemini",
        "litellm_prefix": "gemini",
        "default_model": "gemini-2.0-flash",
        "api_key_env": "GEMINI_API_KEY",
        "cost_input": 0.0001, "cost_output": 0.0004,
    },
    "mistral": {
        "name": "Mistral AI",
        "litellm_prefix": "mistral",
        "default_model": "mistral-large-latest",
        "api_key_env": "MISTRAL_API_KEY",
        "cost_input": 0.002, "cost_output": 0.006,
    },
    "cohere": {
        "name": "Cohere",
        "litellm_prefix": "cohere_chat",
        "default_model": "command-r-plus",
        "api_key_env": "COHERE_API_KEY",
        "cost_input": 0.0025, "cost_output": 0.01,
    },
    "ollama": {
        "name": "Ollama (Local)",
        "litellm_prefix": "ollama_chat",
        "default_model": "llama3.2",
        "api_key_env": "",
        "api_base": "http://localhost:11434",
        "cost_input": 0, "cost_output": 0,
        "is_local": True,
    },
    "openrouter": {
        "name": "OpenRouter",
        "litellm_prefix": "openrouter",
        "default_model": "meta-llama/llama-3.3-70b-instruct",
        "api_key_env": "OPENROUTER_API_KEY",
        "cost_input": 0.0004, "cost_output": 0.0004,
    },
    "groq": {
        "name": "Groq",
        "litellm_prefix": "groq",
        "default_model": "llama-3.3-70b-versatile",
        "api_key_env": "GROQ_API_KEY",
        "cost_input": 0.00059, "cost_output": 0.00079,
    },
    "xai": {
        "name": "xAI Grok",
        "litellm_prefix": "xai",
        "default_model": "grok-3-mini",
        "api_key_env": "XAI_API_KEY",
        "cost_input": 0.0003, "cost_output": 0.0005,
    },
    "together": {
        "name": "Together AI",
        "litellm_prefix": "together_ai",
        "default_model": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
        "api_key_env": "TOGETHER_API_KEY",
        "cost_input": 0.00088, "cost_output": 0.00088,
    },
    "huggingface": {
        "name": "HuggingFace Inference",
        "litellm_prefix": "huggingface",
        "default_model": "meta-llama/Llama-3.3-70B-Instruct",
        "api_key_env": "HF_TOKEN",
        "cost_input": 0, "cost_output": 0,
    },
    "deepseek": {
        "name": "DeepSeek",
        "litellm_prefix": "deepseek",
        "default_model": "deepseek-chat",
        "api_key_env": "DEEPSEEK_API_KEY",
        "cost_input": 0.00014, "cost_output": 0.00028,
    },
}


@dataclass
class LLMResponse:
    """Universal LLM response."""
    content: str = ""
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    latency_ms: float = 0.0
    cost_usd: float = 0.0
    model: str = ""
    provider: str = ""


class UniversalLLM:
    """
    Universal LLM client supporting 12 providers.
    Uses LiteLLM when available, falls back to the OpenAI SDK.
    """

    def __init__(self, provider: str = "openai", model: Optional[str] = None,
                 api_key: Optional[str] = None, api_base: Optional[str] = None):
        self.provider_id = provider
        self.provider_config = PROVIDERS.get(provider, PROVIDERS["openai"])
        self.model = model or self.provider_config["default_model"]
        self._api_key = api_key
        self._api_base = api_base
        self._litellm = None
        self._openai_client = None
        self._anthropic_client = None

    def initialize(self):
        """Initialize the appropriate SDK."""
        # Try LiteLLM first (universal)
        try:
            import litellm
            self._litellm = litellm
            litellm.drop_params = True  # auto-drop unsupported params
            logger.info(f"LiteLLM initialized for {self.provider_id}/{self.model}")
            return
        except ImportError:
            pass

        # Fall back to direct SDK
        if self.provider_id == "anthropic":
            try:
                from anthropic import Anthropic
                key = self._api_key or os.getenv(self.provider_config["api_key_env"], "")
                self._anthropic_client = Anthropic(api_key=key)
                logger.info(f"Anthropic SDK initialized: {self.model}")
                return
            except ImportError:
                pass

        # OpenAI SDK (works for OpenAI, Ollama, Groq, Together, etc.)
        try:
            from openai import OpenAI
            api_key_env = self.provider_config.get("api_key_env", "")
            key = self._api_key or os.getenv(api_key_env, "") or "ollama"
            base = self._api_base or self.provider_config.get("api_base", "")

            base_urls = {
                "openai": "https://api.openai.com/v1",
                "gemini": "https://generativelanguage.googleapis.com/v1beta/openai/",
                "mistral": "https://api.mistral.ai/v1",
                "cohere": "https://api.cohere.ai/compatibility/v1",
                "ollama": "http://localhost:11434/v1",
                "openrouter": "https://openrouter.ai/api/v1",
                "groq": "https://api.groq.com/openai/v1",
                "xai": "https://api.x.ai/v1",
                "together": "https://api.together.xyz/v1",
                "huggingface": "https://api-inference.huggingface.co/v1",
                "deepseek": "https://api.deepseek.com/v1",
            }
            base_url = base or base_urls.get(self.provider_id, "https://api.openai.com/v1")

            self._openai_client = OpenAI(base_url=base_url, api_key=key)
            logger.info(f"OpenAI-compat SDK initialized for {self.provider_id}: {base_url}")
        except ImportError:
            logger.warning("No SDK available. Install: pip install openai litellm anthropic")

    def generate(self, messages: List[Dict[str, str]],
                 temperature: float = 0, max_tokens: int = 1024,
                 json_mode: bool = False) -> LLMResponse:
        """Generate a response using the configured provider."""
        start = time.perf_counter()
        cost_in = self.provider_config.get("cost_input", 0)
        cost_out = self.provider_config.get("cost_output", 0)

        # ── LiteLLM path ──────────────────────────────
        if self._litellm:
            return self._call_litellm(messages, temperature, max_tokens, json_mode, start, cost_in, cost_out)

        # ── Anthropic direct path ─────────────────────
        if self._anthropic_client:
            return self._call_anthropic(messages, temperature, max_tokens, start, cost_in, cost_out)

        # ── OpenAI-compat path ────────────────────────
        if self._openai_client:
            return self._call_openai(messages, temperature, max_tokens, json_mode, start, cost_in, cost_out)

        # ── Mock fallback ─────────────────────────────
        return LLMResponse(
            content="[No LLM SDK available. Install: pip install openai]",
            input_tokens=50, output_tokens=20, total_tokens=70,
            latency_ms=100, cost_usd=0, model=self.model, provider=self.provider_id,
        )

    def _call_litellm(self, messages, temp, max_tok, json_mode, start, ci, co):
        prefix = self.provider_config["litellm_prefix"]
        model_str = f"{prefix}/{self.model}"
        kwargs = {"model": model_str, "messages": messages,
                  "temperature": temp, "max_tokens": max_tok}
        if json_mode:
            kwargs["response_format"] = {"type": "json_object"}
        if self.provider_config.get("api_base"):
            kwargs["api_base"] = self.provider_config["api_base"]

        resp = self._litellm.completion(**kwargs)
        elapsed = (time.perf_counter() - start) * 1000
        u = resp.usage
        return LLMResponse(
            content=resp.choices[0].message.content or "",
            input_tokens=u.prompt_tokens, output_tokens=u.completion_tokens,
            total_tokens=u.total_tokens, latency_ms=elapsed,
            cost_usd=(u.prompt_tokens / 1000 * ci + u.completion_tokens / 1000 * co),
            model=self.model, provider=self.provider_id,
        )

    def _call_anthropic(self, messages, temp, max_tok, start, ci, co):
        sys_msg = next((m["content"] for m in messages if m["role"] == "system"), None)
        user_msgs = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]
        kwargs = {"model": self.model, "max_tokens": max_tok,
                  "temperature": temp, "messages": user_msgs}
        if sys_msg:
            kwargs["system"] = sys_msg
        msg = self._anthropic_client.messages.create(**kwargs)
        elapsed = (time.perf_counter() - start) * 1000
        content = msg.content[0].text if msg.content and msg.content[0].type == "text" else ""
        return LLMResponse(
            content=content,
            input_tokens=msg.usage.input_tokens, output_tokens=msg.usage.output_tokens,
            total_tokens=msg.usage.input_tokens + msg.usage.output_tokens,
            latency_ms=elapsed,
            cost_usd=(msg.usage.input_tokens / 1000 * ci + msg.usage.output_tokens / 1000 * co),
            model=self.model, provider=self.provider_id,
        )

    def _call_openai(self, messages, temp, max_tok, json_mode, start, ci, co):
        kwargs = {"model": self.model, "messages": messages,
                  "temperature": temp, "max_tokens": max_tok}
        if json_mode:
            kwargs["response_format"] = {"type": "json_object"}
        resp = self._openai_client.chat.completions.create(**kwargs)
        elapsed = (time.perf_counter() - start) * 1000
        u = resp.usage
        return LLMResponse(
            content=resp.choices[0].message.content or "",
            input_tokens=u.prompt_tokens if u else 0,
            output_tokens=u.completion_tokens if u else 0,
            total_tokens=u.total_tokens if u else 0, latency_ms=elapsed,
            cost_usd=((u.prompt_tokens if u else 0) / 1000 * ci + (u.completion_tokens if u else 0) / 1000 * co),
            model=self.model, provider=self.provider_id,
        )

    # ── Convenience methods ──────────────────────────

    def generate_answer(self, query, context, system_prompt=None):
        if not system_prompt:
            system_prompt = "Answer accurately using ONLY the provided context. Be concise."
        return self.generate([
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"},
        ], max_tokens=512)

    def extract_entities(self, text):
        return self.generate([
            {"role": "system", "content": 'Extract entities and relationships. Return JSON: {"entities": [{"name": "...", "type": "PERSON|ORG|LOCATION|EVENT|CONCEPT"}], "relations": [{"source": "...", "target": "...", "type": "...", "description": "..."}]}'},
            {"role": "user", "content": text},
        ], max_tokens=2048, json_mode=True)

    def extract_keywords(self, query):
        return self.generate([
            {"role": "system", "content": 'Extract keywords. Return JSON: {"high_level": ["themes"], "low_level": ["entities"]}'},
            {"role": "user", "content": query},
        ], max_tokens=256, json_mode=True)


def get_available_providers() -> List[str]:
    """Return list of provider IDs with valid API keys."""
    available = []
    for pid, cfg in PROVIDERS.items():
        if cfg.get("is_local"):
            available.append(pid)
        elif not cfg.get("api_key_env"):
            available.append(pid)
        elif os.getenv(cfg["api_key_env"]):
            available.append(pid)
    return available


def check_ollama() -> Dict[str, Any]:
    """Check if Ollama is running locally."""
    import urllib.request
    try:
        req = urllib.request.Request("http://localhost:11434/api/tags", method="GET")
        with urllib.request.urlopen(req, timeout=2) as resp:
            data = json.loads(resp.read())
        return {"ok": True, "models": [m["name"] for m in data.get("models", [])]}
    except Exception:
        return {"ok": False, "models": []}
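For reference, a minimal usage sketch of the layer added above. The provider, model, query, and context strings are illustrative assumptions (not fixed by the module), and it assumes the relevant API key (e.g. OPENAI_API_KEY) is already set in the environment.

# Illustrative usage sketch; example values only.
from graphrag.layers.universal_llm import (
    UniversalLLM, get_available_providers, check_ollama,
)

print(get_available_providers())   # provider IDs whose API keys are present (or local)
print(check_ollama())              # {"ok": bool, "models": [...]} for a local Ollama server

llm = UniversalLLM(provider="openai", model="gpt-4o-mini")
llm.initialize()                   # uses LiteLLM if installed, else a direct SDK

resp = llm.generate_answer(
    query="What does the universal layer do?",
    context="It routes chat completions to one of 12 providers.",
)
print(resp.content, resp.total_tokens, f"{resp.cost_usd:.6f} USD")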