Rohan03 commited on
Commit
73ecef8
·
verified ·
1 Parent(s): b199fa3

Add purpose_agent/llm_backend.py

Browse files
Files changed (1) hide show
  1. purpose_agent/llm_backend.py +363 -0
purpose_agent/llm_backend.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM Backend — Swappable inference layer.
3
+
4
+ Supports: HuggingFace Inference Providers, OpenAI, Anthropic, local models,
5
+ or any custom backend. Swap by changing one constructor call.
6
+
7
+ Design: Abstract base class with structured output support.
8
+ Inspired by smolagents Model interface + HF Inference Providers API.
9
+ """
10
+
11
from __future__ import annotations

import json
import logging
import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # Message types (OpenAI-compatible chat format)
25
+ # ---------------------------------------------------------------------------
26
+
27
@dataclass
class ChatMessage:
    """A single turn in an OpenAI-compatible chat conversation."""

    role: str  # "system", "user", "assistant"
    content: str  # the message text for this turn
31
+
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Abstract LLM Backend
35
+ # ---------------------------------------------------------------------------
36
+
37
class LLMBackend(ABC):
    """
    Abstract LLM backend. All modules call this — swap the implementation
    to change the underlying model without touching any other code.

    Subclasses must implement `generate()` which takes messages and returns
    a string. Optionally implement `generate_structured()` for JSON-schema
    constrained generation (used by the Purpose Function for reliable scoring).
    """

    @abstractmethod
    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        """Generate a text completion from chat messages."""
        ...

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        """
        Generate with a JSON schema constraint.

        Default implementation: append a schema instruction to the last
        message, call `generate()`, then parse JSON out of the response.
        Handles markdown code fences and JSON wrapped in prose. Override
        for backends with native structured-output support.

        Raises:
            json.JSONDecodeError: if no parseable JSON can be extracted.
        """
        schema_instruction = (
            f"\n\nYou MUST respond with valid JSON matching this schema:\n"
            f"```json\n{json.dumps(schema, indent=2)}\n```\n"
            f"Respond ONLY with the JSON object, no other text."
        )
        # Copy the message list so the caller's list is never mutated,
        # and rebuild the last message via type(last) so any message
        # subclass with the same (role, content) constructor survives.
        augmented = list(messages)
        last = augmented[-1]
        augmented[-1] = type(last)(
            role=last.role, content=last.content + schema_instruction
        )
        raw = self.generate(augmented, temperature=temperature, max_tokens=max_tokens)

        text = self._strip_code_fence(raw.strip())
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            # Models often wrap the JSON in prose despite the instruction.
            # Fall back to the outermost {...} span before giving up.
            start, end = text.find("{"), text.rfind("}")
            if start == -1 or end <= start:
                raise
            return json.loads(text[start : end + 1])

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Return the body of the first ``` fenced block, or `text` unchanged."""
        if not text.startswith("```"):
            return text
        fenced: list[str] = []
        inside = False
        for line in text.split("\n"):
            stripped = line.strip()
            if stripped.startswith("```") and not inside:
                inside = True  # opening fence (may carry a language tag)
                continue
            if stripped == "```" and inside:
                break  # closing fence — ignore anything after it
            if inside:
                fenced.append(line)
        return "\n".join(fenced)
101
+
102
+
103
+ # ---------------------------------------------------------------------------
104
+ # HuggingFace Inference Provider Backend
105
+ # ---------------------------------------------------------------------------
106
+
107
class HFInferenceBackend(LLMBackend):
    """
    Backend backed by `huggingface_hub.InferenceClient` (HF Inference Providers).

    Works with providers such as Cerebras, Novita, Fireworks, Together and
    SambaNova, and with any chat model on the HF Hub (Qwen, Llama, Mistral,
    DeepSeek, ...).

    Example:
        backend = HFInferenceBackend(
            model_id="Qwen/Qwen3-32B",
            provider="cerebras",
        )
    """

    def __init__(
        self,
        model_id: str = "Qwen/Qwen3-32B",
        provider: str = "auto",
        api_key: str | None = None,
    ):
        # Imported lazily so the module loads even without huggingface_hub.
        from huggingface_hub import InferenceClient

        self.model_id = model_id
        self.provider = provider
        self.client = InferenceClient(
            provider=provider,
            api_key=api_key or os.environ.get("HF_TOKEN"),
        )

    @staticmethod
    def _payload(messages: list[ChatMessage]) -> list[dict[str, str]]:
        """Serialize ChatMessage objects into OpenAI-style role/content dicts."""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        """Plain-text chat completion via the configured provider."""
        result = self.client.chat_completion(
            model=self.model_id,
            messages=self._payload(messages),
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop or [],
        )
        return result.choices[0].message.content

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        """Completion constrained by `schema` via the provider's `response_format`."""
        result = self.client.chat_completion(
            model=self.model_id,
            messages=self._payload(messages),
            temperature=temperature,
            max_tokens=max_tokens,
            response_format={
                "type": "json_schema",
                "json_schema": {"schema": schema},
            },
        )
        return json.loads(result.choices[0].message.content)
172
+
173
+
174
+ # ---------------------------------------------------------------------------
175
+ # OpenAI-Compatible Backend (OpenAI, Azure, vLLM, Ollama, LiteLLM)
176
+ # ---------------------------------------------------------------------------
177
+
178
class OpenAICompatibleBackend(LLMBackend):
    """
    Backend for any endpoint speaking the OpenAI chat-completions protocol.

    Examples:
        # OpenAI itself
        backend = OpenAICompatibleBackend(model="gpt-4o")

        # Local Ollama
        backend = OpenAICompatibleBackend(
            model="llama3.2",
            base_url="http://localhost:11434/v1",
            api_key="ollama",
        )

        # vLLM server
        backend = OpenAICompatibleBackend(
            model="meta-llama/Llama-3.2-3B-Instruct",
            base_url="http://localhost:8000/v1",
            api_key="token-placeholder",
        )

        # HF Inference via the OpenAI SDK (for structured output with .parse())
        backend = OpenAICompatibleBackend(
            model="Qwen/Qwen3-32B",
            base_url="https://router.huggingface.co/cerebras/v1",
            api_key=os.environ["HF_TOKEN"],
        )
    """

    def __init__(
        self,
        model: str = "gpt-4o",
        base_url: str | None = None,
        api_key: str | None = None,
    ):
        # Lazy import keeps `openai` an optional dependency of the module.
        from openai import OpenAI

        self.model = model
        self.client = OpenAI(
            base_url=base_url,
            api_key=api_key or os.environ.get("OPENAI_API_KEY"),
        )

    @staticmethod
    def _as_dicts(messages: list[ChatMessage]) -> list[dict[str, str]]:
        """Serialize ChatMessage objects into the wire format."""
        return [{"role": item.role, "content": item.content} for item in messages]

    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        """Free-form chat completion."""
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=self._as_dicts(messages),
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
        )
        return completion.choices[0].message.content

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        """Completion constrained by `schema` through `response_format`."""
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=self._as_dicts(messages),
            temperature=temperature,
            max_tokens=max_tokens,
            response_format={
                "type": "json_schema",
                "json_schema": {"name": "purpose_score", "schema": schema},
            },
        )
        return json.loads(completion.choices[0].message.content)
258
+
259
+
260
+ # ---------------------------------------------------------------------------
261
+ # Mock Backend (for testing without API calls)
262
+ # ---------------------------------------------------------------------------
263
+
264
class MockLLMBackend(LLMBackend):
    """
    Deterministic mock backend for testing the framework without LLM calls.

    Returns canned responses based on keywords in the prompt, or a default.
    Custom response handlers can be registered with `register_handler`;
    every call is recorded and exposed via `call_log` for test assertions.
    """

    def __init__(self):
        # (keyword, response) pairs, checked in registration order.
        self._handlers: list[tuple[str, str | Callable]] = []
        # Fallback payload for generate_structured() when no handler matches.
        self._structured_default: dict[str, Any] = {}
        # Chronological record of every generate*/generate_structured call.
        self._call_log: list[dict] = []

    def register_handler(
        self, keyword: str, response: str | Callable
    ) -> "MockLLMBackend":
        """Add a keyword-matched response handler. Checked in order. Fluent."""
        self._handlers.append((keyword, response))
        return self

    def set_structured_default(self, default: dict[str, Any]) -> "MockLLMBackend":
        """Set the default response for structured generation. Fluent."""
        self._structured_default = default
        return self

    @property
    def call_log(self) -> list[dict]:
        """Record of all calls made so far (for test assertions)."""
        return self._call_log

    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        """Return the first matching canned response, or a generic echo."""
        # Keywords are matched case-insensitively against ALL message contents.
        full_text = " ".join(m.content for m in messages)
        self._call_log.append({
            "method": "generate",
            # Truncate logged content so the log stays small in long tests.
            "messages": [{"role": m.role, "content": m.content[:200]} for m in messages],
        })

        for keyword, response in self._handlers:
            if keyword.lower() in full_text.lower():
                if callable(response):
                    return response(messages)
                return response

        # Default: echo the last user message with a generic response.
        last_user = next(
            (m.content for m in reversed(messages) if m.role == "user"),
            "no input",
        )
        return f"[MockLLM] Acknowledged: {last_user[:100]}"

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        """Return a dict from a handler, the structured default, or schema stubs."""
        self._call_log.append({
            "method": "generate_structured",
            "schema_keys": list(schema.get("properties", {}).keys()),
        })
        # Try keyword handlers first — they may return JSON strings or dicts.
        full_text = " ".join(m.content for m in messages)
        for keyword, response in self._handlers:
            if keyword.lower() in full_text.lower():
                result = response(messages) if callable(response) else response
                if isinstance(result, str):
                    try:
                        return json.loads(result)
                    except json.JSONDecodeError:
                        continue  # not JSON — fall through to the next handler
                elif isinstance(result, dict):
                    return result

        # Fall back to the caller-supplied structured default.
        if self._structured_default:
            return self._structured_default

        # Build a minimal schema-conforming response. Arrays and objects get
        # empty containers (a "mock_<key>" string would violate the schema).
        placeholders: dict[str, Any] = {
            "number": 5.0,
            "integer": 5,
            "boolean": True,
            "array": [],
            "object": {},
        }
        return {
            key: placeholders.get(prop.get("type", "string"), f"mock_{key}")
            for key, prop in schema.get("properties", {}).items()
        }