Rohan03 committed
Commit 7eacaee · verified · Parent: 8f2700b

v0.2.0: Add purpose_agent/slm_backends.py

Files changed (1)
  1. purpose_agent/slm_backends.py +446 -0
purpose_agent/slm_backends.py ADDED
@@ -0,0 +1,446 @@
"""
SLM-Native Backends — First-class support for Small Language Models.

Purpose Agent is an agentic framework designed natively for SLMs.
These backends handle the unique challenges of small models:
- Grammar-constrained JSON output (SLMs can't reliably produce JSON from prompts alone)
- Prompt compression for small context windows (8K-32K)
- Adaptive prompting (shorter system prompts, schema-first format)
- Token budget management

Supported backends:
- OllamaBackend: Local serving via Ollama (CPU/GPU, any GGUF model)
- LlamaCppBackend: Direct llama-cpp-python (CPU/Apple Silicon, GGUF)
- TransformersBackend: HuggingFace transformers (GPU, native weights)

All backends implement the same LLMBackend interface — swap freely.
"""

from __future__ import annotations

import json
import logging
import re
from typing import Any, Iterator

from purpose_agent.llm_backend import ChatMessage, LLMBackend

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# SLM Prompt Compressor — reduces prompt size for small context windows
# ---------------------------------------------------------------------------

class SLMPromptCompressor:
    """
    Compresses prompts for small context windows without losing critical info.

    Strategies (from TinyAgent arxiv:2409.00608 + LLMLingua-2 arxiv:2403.12968):
    1. Schema-first: Move JSON schema to top, compress descriptions
    2. History truncation: Summarize old steps, keep recent ones verbatim
    3. Example reduction: Fewer few-shot examples for SLMs
    4. Whitespace stripping: Remove unnecessary formatting

    No external dependencies — pure Python compression.
    For better compression, install llmlingua: pip install llmlingua
    """

    def __init__(self, max_tokens: int = 4096, aggressive: bool = False):
        self.max_tokens = max_tokens
        self.aggressive = aggressive

    def compress(self, text: str, budget: int | None = None) -> str:
        """Compress text to fit within token budget."""
        budget = budget or self.max_tokens
        # Rough estimate: 1 token ≈ 4 chars
        char_budget = budget * 4

        if len(text) <= char_budget:
            return text

        compressed = text
        # Stage 1: Strip excessive whitespace
        compressed = re.sub(r'\n{3,}', '\n\n', compressed)
        compressed = re.sub(r'[ \t]{2,}', ' ', compressed)
        compressed = re.sub(r'^\s+', '', compressed, flags=re.MULTILINE)

        if len(compressed) <= char_budget:
            return compressed

        # Stage 2: Shorten verbose sections
        if self.aggressive:
            # Remove markdown formatting
            compressed = re.sub(r'\*\*([^*]+)\*\*', r'\1', compressed)
            compressed = re.sub(r'#{1,3}\s+', '', compressed)
            # Shorten common verbose phrases
            replacements = {
                "You MUST respond with": "Respond with",
                "Based on the current state and your goal, ": "",
                "Respond in this exact JSON format:": "JSON format:",
                "Step-by-step justification": "Justification",
                "Specific observable state changes": "State changes",
            }
            for old, new in replacements.items():
                compressed = compressed.replace(old, new)

        if len(compressed) <= char_budget:
            return compressed

        # Stage 3: Truncate from middle (keep start + end)
        keep_start = char_budget * 2 // 3
        keep_end = char_budget // 3
        compressed = compressed[:keep_start] + "\n...[truncated]...\n" + compressed[-keep_end:]

        return compressed
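
    # Worked example (illustrative, not from the original commit): with
    # budget=100 the character budget is 400. A 1,000-char prompt is first
    # whitespace-stripped; if still over budget (and aggressive), verbose
    # phrases are shortened; as a last resort, stage 3 keeps the first
    # 400 * 2 // 3 = 266 chars and the last 400 // 3 = 133 chars around the
    # "...[truncated]..." marker.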

    def compress_messages(
        self, messages: list[ChatMessage], budget: int | None = None
    ) -> list[ChatMessage]:
        """Compress a message list to fit within token budget."""
        budget = budget or self.max_tokens
        total_chars = sum(len(m.content) for m in messages)
        char_budget = budget * 4

        if total_chars <= char_budget:
            return messages

        result = []
        # Always keep system prompt (compress it), always keep last user message
        for i, msg in enumerate(messages):
            if msg.role == "system":
                result.append(ChatMessage(
                    role="system",
                    content=self.compress(msg.content, budget=budget // 3),
                ))
            elif i == len(messages) - 1:
                # Last message — keep more of it
                result.append(ChatMessage(
                    role=msg.role,
                    content=self.compress(msg.content, budget=budget // 2),
                ))
            else:
                result.append(ChatMessage(
                    role=msg.role,
                    content=self.compress(msg.content, budget=budget // 4),
                ))
        return result
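

# Illustrative usage (not part of the original commit): squeeze an oversized
# prompt into a tiny 64-token budget. Values here are made up for the demo.
def _demo_compressor() -> None:
    compressor = SLMPromptCompressor(max_tokens=64, aggressive=True)
    long_prompt = "You MUST respond with JSON.\n\n\n\n" + ("filler " * 500)
    short = compressor.compress(long_prompt)
    # 64 tokens * 4 chars/token = 256-char budget, plus the truncation marker
    assert len(short) <= 64 * 4 + len("\n...[truncated]...\n")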


# ---------------------------------------------------------------------------
# Ollama Backend — Best for local SLMs
# ---------------------------------------------------------------------------

class OllamaBackend(LLMBackend):
    """
    Local model serving via Ollama with grammar-constrained JSON output.

    Ollama's grammar engine (via llama.cpp) forces valid JSON output from
    ANY model — even tiny ones that can't produce reliable JSON from prompts.
    This is the key advantage for SLM agent use.

    Setup:
    1. Install Ollama: https://ollama.ai
    2. Pull a model: ollama pull qwen3:1.7b
    3. Use this backend:

    Example:
        backend = OllamaBackend(model="qwen3:1.7b")    # 1.7B params, runs on CPU
        backend = OllamaBackend(model="llama3.2:1b")   # 1B params, ultra-light
        backend = OllamaBackend(model="phi4-mini")     # 3.8B, best tool-use
        backend = OllamaBackend(model="smollm2:1.7b")  # HF native SLM

    Also works with large models:
        backend = OllamaBackend(model="qwen3:32b")     # Full LLM
    """

    def __init__(
        self,
        model: str = "qwen3:1.7b",
        host: str = "http://localhost:11434",
        context_window: int = 8192,
        compress_prompts: bool = True,
        num_ctx: int | None = None,
    ):
        self.model = model
        self.host = host
        self.context_window = context_window
        self.compress_prompts = compress_prompts
        self.num_ctx = num_ctx or context_window
        self.compressor = SLMPromptCompressor(
            max_tokens=context_window, aggressive=(context_window <= 8192)
        )
        self._token_count = 0

    def _get_client(self):
        """Lazy import ollama client."""
        try:
            from ollama import Client
            return Client(host=self.host)
        except ImportError:
            raise ImportError(
                "Ollama client not installed. Run: pip install ollama\n"
                "Also install the Ollama server: https://ollama.ai"
            )

    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        client = self._get_client()

        if self.compress_prompts:
            messages = self.compressor.compress_messages(messages, self.context_window)

        msg_dicts = [{"role": m.role, "content": m.content} for m in messages]

        response = client.chat(
            model=self.model,
            messages=msg_dicts,
            options={
                "temperature": temperature,
                "num_predict": max_tokens,
                "num_ctx": self.num_ctx,
                "stop": stop or [],
            },
        )

        content = response.message.content or ""
        # Track token usage for cost accounting (counts may be absent on some responses)
        self._token_count += (response.prompt_eval_count or 0) + (response.eval_count or 0)
        return content

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        """
        Grammar-constrained JSON generation.

        Ollama uses llama.cpp's grammar engine to FORCE valid JSON output
        matching the schema. This works even with tiny models that can't
        produce valid JSON from prompts alone.
        """
        client = self._get_client()

        if self.compress_prompts:
            messages = self.compressor.compress_messages(messages, self.context_window)

        msg_dicts = [{"role": m.role, "content": m.content} for m in messages]

        response = client.chat(
            model=self.model,
            messages=msg_dicts,
            format=schema,  # Grammar-constrained output!
            options={
                "temperature": temperature,
                "num_predict": max_tokens,
                "num_ctx": self.num_ctx,
            },
        )

        content = response.message.content or "{}"
        self._token_count += (response.prompt_eval_count or 0) + (response.eval_count or 0)
        return json.loads(content)
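
    # Example (illustrative, not from the original commit): shape of a call.
    #   schema = {
    #       "type": "object",
    #       "properties": {"action": {"type": "string"}},
    #       "required": ["action"],
    #   }
    #   backend.generate_structured(msgs, schema)  # -> {"action": "..."}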

    def generate_stream(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> Iterator[str]:
        """Streaming generation — yields tokens as they're produced."""
        client = self._get_client()

        if self.compress_prompts:
            messages = self.compressor.compress_messages(messages, self.context_window)

        msg_dicts = [{"role": m.role, "content": m.content} for m in messages]

        stream = client.chat(
            model=self.model,
            messages=msg_dicts,
            stream=True,
            options={
                "temperature": temperature,
                "num_predict": max_tokens,
                "num_ctx": self.num_ctx,
            },
        )

        for chunk in stream:
            token = chunk.message.content or ""
            if token:
                yield token

    @property
    def total_tokens(self) -> int:
        return self._token_count
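

# Illustrative usage (not part of the original commit). Assumes a local Ollama
# server is running and the model has been pulled, e.g. `ollama pull qwen3:1.7b`.
def _demo_ollama_structured() -> None:
    backend = OllamaBackend(model="qwen3:1.7b")
    schema = {
        "type": "object",
        "properties": {
            "action": {"type": "string"},
            "confidence": {"type": "number"},
        },
        "required": ["action", "confidence"],
    }
    result = backend.generate_structured(
        [ChatMessage(role="user", content="Pick the next action: open or close the file?")],
        schema=schema,
    )
    # Grammar-constrained decoding guarantees these keys exist with these types.
    print(result["action"], result["confidence"])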


# ---------------------------------------------------------------------------
# LlamaCpp Backend — Direct CPU/Apple Silicon/GGUF
# ---------------------------------------------------------------------------

class LlamaCppBackend(LLMBackend):
    """
    Direct llama-cpp-python backend for GGUF models.

    Best for: CPU inference, Apple Silicon, edge deployment, offline use.

    Example:
        backend = LlamaCppBackend(model_path="./qwen2.5-1.5b-instruct-q4_k_m.gguf")
        backend = LlamaCppBackend(
            model_path="./phi-4-mini-q4.gguf",
            n_ctx=4096,
            n_gpu_layers=35,  # Offload to GPU
        )
    """

    def __init__(
        self,
        model_path: str,
        n_ctx: int = 4096,
        n_gpu_layers: int = 0,
        verbose: bool = False,
    ):
        try:
            from llama_cpp import Llama
        except ImportError:
            raise ImportError("llama-cpp-python not installed. Run: pip install llama-cpp-python")

        self.model_path = model_path
        self.llm = Llama(
            model_path=model_path,
            n_ctx=n_ctx,
            n_gpu_layers=n_gpu_layers,
            verbose=verbose,
        )
        self.compressor = SLMPromptCompressor(max_tokens=n_ctx, aggressive=True)
        self._token_count = 0

    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        messages = self.compressor.compress_messages(messages)
        msg_dicts = [{"role": m.role, "content": m.content} for m in messages]

        response = self.llm.create_chat_completion(
            messages=msg_dicts,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
        )

        content = response["choices"][0]["message"]["content"] or ""
        usage = response.get("usage", {})
        self._token_count += usage.get("total_tokens", 0)
        return content

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        """Grammar-constrained JSON via llama.cpp GBNF grammar."""
        from llama_cpp import LlamaGrammar

        grammar = LlamaGrammar.from_json_schema(json.dumps(schema))
        messages = self.compressor.compress_messages(messages)
        msg_dicts = [{"role": m.role, "content": m.content} for m in messages]

        response = self.llm.create_chat_completion(
            messages=msg_dicts,
            temperature=temperature,
            max_tokens=max_tokens,
            grammar=grammar,
        )

        content = response["choices"][0]["message"]["content"] or "{}"
        usage = response.get("usage", {})
        self._token_count += usage.get("total_tokens", 0)
        return json.loads(content)
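
    # Note (illustrative, not from the original commit): from_json_schema
    # compiles the JSON Schema into a GBNF grammar, restricting token sampling
    # to strings that parse as schema-valid JSON, so json.loads above should
    # only fail if the completion was cut off by max_tokens.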

    def generate_stream(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
    ) -> Iterator[str]:
        messages = self.compressor.compress_messages(messages)
        msg_dicts = [{"role": m.role, "content": m.content} for m in messages]

        stream = self.llm.create_chat_completion(
            messages=msg_dicts,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=True,
        )

        for chunk in stream:
            delta = chunk.get("choices", [{}])[0].get("delta", {})
            token = delta.get("content", "")
            if token:
                yield token

    @property
    def total_tokens(self) -> int:
        return self._token_count
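

# Illustrative usage (not part of the original commit). The GGUF path is
# hypothetical; any instruct-tuned GGUF model file works.
def _demo_llamacpp_stream() -> None:
    backend = LlamaCppBackend(
        model_path="./models/qwen2.5-1.5b-instruct-q4_k_m.gguf",  # hypothetical path
        n_ctx=4096,
    )
    prompt = [ChatMessage(role="user", content="Explain GBNF grammars in one sentence.")]
    for token in backend.generate_stream(prompt, temperature=0.5, max_tokens=64):
        print(token, end="", flush=True)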


# ---------------------------------------------------------------------------
# Model Registry — Easy model selection for SLMs
# ---------------------------------------------------------------------------

# Recommended SLMs for agent tasks, ranked by capability
SLM_REGISTRY = {
    # Model ID → (Ollama name, context window, description)
    "phi-4-mini": ("phi4-mini", 16384, "3.8B, best schema compliance, Microsoft"),
    "qwen3-1.7b": ("qwen3:1.7b", 32768, "1.7B, strong function calling, 32K context"),
    "qwen3-0.6b": ("qwen3:0.6b", 32768, "0.6B, ultra-light, 32K context"),
    "qwen2.5-1.5b": ("qwen2.5:1.5b", 32768, "1.5B, proven tool-use"),
    "llama-3.2-3b": ("llama3.2:3b", 131072, "3B, 128K context, Meta"),
    "llama-3.2-1b": ("llama3.2:1b", 131072, "1B, smallest Llama, 128K context"),
    "smollm2-1.7b": ("smollm2:1.7b", 8192, "1.7B, HF native, 8K context (tight!)"),
    "gemma-3-1b": ("gemma3:1b", 32768, "1B, Google, text-only Gemma 3"),
}

def create_slm_backend(
    model_key: str = "qwen3-1.7b",
    host: str = "http://localhost:11434",
) -> OllamaBackend:
    """
    Create an SLM backend from the registry.

    Usage:
        backend = create_slm_backend("phi-4-mini")    # Best overall
        backend = create_slm_backend("qwen3-0.6b")    # Ultra-light
        backend = create_slm_backend("llama-3.2-1b")  # Smallest Llama
    """
    if model_key not in SLM_REGISTRY:
        available = ", ".join(SLM_REGISTRY.keys())
        raise ValueError(f"Unknown SLM '{model_key}'. Available: {available}")

    ollama_name, ctx_window, desc = SLM_REGISTRY[model_key]
    logger.info(f"Creating SLM backend: {model_key} ({desc})")

    return OllamaBackend(
        model=ollama_name,
        host=host,
        context_window=ctx_window,
        compress_prompts=True,
    )
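

# Illustrative sketch (not part of the original commit): because every backend
# implements the same LLMBackend interface, callers can swap them freely, as
# the module docstring promises. The commented entry shows the drop-in shape;
# the GGUF path is hypothetical.
def _demo_swap_backends() -> None:
    msgs = [ChatMessage(role="user", content="Say hello in five words.")]
    backends: list[LLMBackend] = [
        create_slm_backend("qwen3-1.7b"),  # Ollama-served SLM from the registry
        # LlamaCppBackend(model_path="./models/some-slm-q4_k_m.gguf"),  # hypothetical
    ]
    for backend in backends:
        print(backend.generate(msgs, temperature=0.2, max_tokens=32))
        print("tokens so far:", backend.total_tokens)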