Track 1: purpose_agent/llm_backend.py

purpose_agent/llm_backend.py (changed: +92 −2). Two changes: a shared `_strip_thinking` helper on the backend base class, applied at both OpenAI-compatible return sites, and a new multi-provider router, `resolve_backend`.
```diff
@@ -44,6 +44,21 @@ class LLMBackend(ABC):
     constrained generation (used by the Purpose Function for reliable scoring).
     """
 
+    @staticmethod
+    def _strip_thinking(text: str) -> str:
+        """
+        Strip <think>...</think> tags from model output.
+
+        Many reasoning models (Qwen3, DeepSeek-R1, etc.) wrap their
+        chain-of-thought in <think> tags. We keep only the final answer.
+        """
+        import re
+        # Remove <think>...</think> blocks (non-greedy, DOTALL handles multiline)
+        cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+        # Also handle unclosed <think> tags (model cut off mid-thought)
+        cleaned = re.sub(r'<think>.*$', '', cleaned, flags=re.DOTALL)
+        return cleaned.strip()
+
     @abstractmethod
     def generate(
         self,
```
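For reviewers who want to poke at the regex, here is a minimal standalone sketch of the same two-pass stripping logic (copied out of the class so it runs without the package), with the edge cases that motivate each pass:

```python
import re

def strip_thinking(text: str) -> str:
    """Standalone copy of LLMBackend._strip_thinking for quick testing."""
    # Pass 1: non-greedy match removes each closed <think>...</think> block.
    cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    # Pass 2: drop a trailing unclosed <think> (generation truncated mid-thought).
    cleaned = re.sub(r'<think>.*$', '', cleaned, flags=re.DOTALL)
    return cleaned.strip()

assert strip_thinking("<think>2+2...\ncarry the...</think>Answer: 4") == "Answer: 4"
assert strip_thinking("Answer: 4 <think>wait, should I") == "Answer: 4"
assert strip_thinking("no tags at all") == "no tags at all"
```

Note that pass 2 discards everything after an unclosed tag; per the diff's comment, the assumption is that a generation cut off mid-thought has no final answer following it.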
```diff
@@ -148,7 +163,7 @@ class HFInferenceBackend(LLMBackend):
             max_tokens=max_tokens,
             stop=stop or [],
         )
-        return response.choices[0].message.content
+        return self._strip_thinking(response.choices[0].message.content or "")
 
     def generate_structured(
         self,
```
```diff
@@ -234,7 +249,7 @@ class OpenAICompatibleBackend(LLMBackend):
             max_tokens=max_tokens,
             stop=stop,
         )
-        return response.choices[0].message.content
+        return self._strip_thinking(response.choices[0].message.content or "")
 
     def generate_structured(
         self,
```
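The same one-line change lands in both `HFInferenceBackend.generate` and `OpenAICompatibleBackend.generate`, and both add an `or ""` guard: in OpenAI-style responses `message.content` can be `None` (for example when a model emits only tool calls), and passing `None` into `re.sub` would raise a `TypeError`. A minimal illustration, using hypothetical `SimpleNamespace` shapes to stand in for the client's response object:

```python
from types import SimpleNamespace

# Stand-in for an OpenAI-style response whose message has no text content.
response = SimpleNamespace(
    choices=[SimpleNamespace(message=SimpleNamespace(content=None))]
)

text = response.choices[0].message.content or ""
assert text == ""  # safe to pass to _strip_thinking; bare None would raise
```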
```diff
@@ -361,3 +376,78 @@ class MockLLMBackend(LLMBackend):
         else:
             result[key] = f"mock_{key}"
         return result
+
+
+# ---------------------------------------------------------------------------
+# Multi-Provider Router
+# ---------------------------------------------------------------------------
+
+# Provider → (base_url, env_var_for_key)
+_PROVIDER_MAP = {
+    "groq": ("https://api.groq.com/openai/v1", "GROQ_API_KEY"),
+    "openai": ("https://api.openai.com/v1", "OPENAI_API_KEY"),
+    "together": ("https://api.together.xyz/v1", "TOGETHER_API_KEY"),
+    "fireworks": ("https://api.fireworks.ai/inference/v1", "FIREWORKS_API_KEY"),
+    "deepseek": ("https://api.deepseek.com/v1", "DEEPSEEK_API_KEY"),
+    "mistral": ("https://api.mistral.ai/v1", "MISTRAL_API_KEY"),
+    "cerebras": ("https://api.cerebras.ai/v1", "CEREBRAS_API_KEY"),
+}
+
+
+def resolve_backend(spec: str, api_key: str | None = None) -> LLMBackend:
+    """
+    Resolve a 'provider:model' string into an LLMBackend.
+
+    Supports every major inference provider via OpenAI-compatible APIs,
+    plus Ollama for local models and HF for HuggingFace Inference.
+
+    Examples:
+        resolve_backend("groq:llama-3.3-70b-versatile")
+        resolve_backend("openai:gpt-4o")
+        resolve_backend("ollama:qwen3:1.7b")
+        resolve_backend("hf:Qwen/Qwen3-32B")
+        resolve_backend("together:meta-llama/Llama-3.3-70B-Instruct-Turbo")
+        resolve_backend("deepseek:deepseek-chat")
+
+    For local models without a provider prefix:
+        resolve_backend("qwen3:1.7b")      # auto-detects Ollama
+        resolve_backend("gpt-4o")          # auto-detects OpenAI
+        resolve_backend("Qwen/Qwen3-32B")  # auto-detects HF
+    """
+    if ":" in spec:
+        parts = spec.split(":", 1)
+        provider = parts[0].lower()
+
+        if provider == "ollama":
+            from purpose_agent.slm_backends import OllamaBackend
+            return OllamaBackend(model=parts[1])
+
+        if provider == "hf":
+            return HFInferenceBackend(model_id=parts[1], api_key=api_key)
+
+        if provider in _PROVIDER_MAP:
+            base_url, env_var = _PROVIDER_MAP[provider]
+            key = api_key or os.environ.get(env_var, "")
+            if not key:
+                raise ValueError(
+                    f"No API key for {provider}. Set {env_var} environment variable "
+                    f"or pass api_key= parameter."
+                )
+            return OpenAICompatibleBackend(
+                model=parts[1], base_url=base_url, api_key=key,
+            )
+
+        # Not a known provider — might be an Ollama model like "qwen3:1.7b"
+        from purpose_agent.slm_backends import OllamaBackend
+        return OllamaBackend(model=spec)
+
+    # No colon — auto-detect
+    if spec.startswith("gpt-") or spec.startswith("o1") or spec.startswith("o3"):
+        key = api_key or os.environ.get("OPENAI_API_KEY", "")
+        return OpenAICompatibleBackend(model=spec, api_key=key)
+
+    if "/" in spec:
+        return HFInferenceBackend(model_id=spec, api_key=api_key)
+
+    from purpose_agent.slm_backends import OllamaBackend
+    return OllamaBackend(model=spec)
```
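A usage sketch of the router, for review purposes. The import path follows the file shown in the diff; the model names are illustrative, and the first call assumes `GROQ_API_KEY` is exported while the last assumes `MISTRAL_API_KEY` is not:

```python
from purpose_agent.llm_backend import resolve_backend

# Explicit prefix: looked up in _PROVIDER_MAP and routed to an
# OpenAI-compatible endpoint (reads GROQ_API_KEY from the environment).
backend = resolve_backend("groq:llama-3.3-70b-versatile")

# Unknown prefix: falls through to Ollama, so local colon-separated
# model tags like this keep working unchanged.
local = resolve_backend("qwen3:1.7b")

# Missing key: fails loudly instead of silently hitting the wrong provider.
try:
    resolve_backend("mistral:mistral-large-latest")
except ValueError as exc:
    print(exc)  # No API key for mistral. Set MISTRAL_API_KEY ...
```

One routing subtlety worth noting: `spec.split(":", 1)` splits only on the first colon, which is what lets `ollama:qwen3:1.7b` carry the tag `qwen3:1.7b` through intact.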