"""
LLM Backend — Swappable inference layer.
Supports: HuggingFace Inference Providers, OpenAI, Anthropic, local models,
or any custom backend. Swap by changing one constructor call.
Design: Abstract base class with structured output support.
Inspired by smolagents Model interface + HF Inference Providers API.
"""
from __future__ import annotations

import json
import logging
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Message types (OpenAI-compatible chat format)
# ---------------------------------------------------------------------------
@dataclass
class ChatMessage:
    role: str  # "system", "user", "assistant"
    content: str

# ---------------------------------------------------------------------------
# Abstract LLM Backend
# ---------------------------------------------------------------------------
class LLMBackend(ABC):
    """
    Abstract LLM backend. All modules call this — swap the implementation
    to change the underlying model without touching any other code.

    Subclasses must implement `generate()`, which takes messages and returns
    a string. Optionally implement `generate_structured()` for JSON-schema
    constrained generation (used by the Purpose Function for reliable scoring).
    """

    @staticmethod
    def _strip_thinking(text: str) -> str:
        """
        Strip <think>...</think> tags from model output.

        Many reasoning models (Qwen3, DeepSeek-R1, etc.) wrap their
        chain-of-thought in <think> tags. We keep only the final answer.
        """
        # Remove closed <think>...</think> blocks (non-greedy, handles multiline)
        cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
        # Also handle an unclosed <think> tag (model cut off mid-thought)
        cleaned = re.sub(r'<think>.*$', '', cleaned, flags=re.DOTALL)
        return cleaned.strip()

    @abstractmethod
    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        """Generate a text completion from chat messages."""
        ...

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        """
        Generate with a JSON schema constraint.

        Default implementation: append the schema instruction to the last
        message and parse JSON from the response. Override for native
        structured output.
        """
        schema_instruction = (
            f"\n\nYou MUST respond with valid JSON matching this schema:\n"
            f"```json\n{json.dumps(schema, indent=2)}\n```\n"
            f"Respond ONLY with the JSON object, no other text."
        )
        augmented = list(messages)
        last = augmented[-1]
        augmented[-1] = ChatMessage(
            role=last.role, content=last.content + schema_instruction
        )
        raw = self.generate(augmented, temperature=temperature, max_tokens=max_tokens)
        # Extract JSON from the response (handle markdown code fences)
        text = raw.strip()
        if text.startswith("```"):
            # Keep only the lines between the opening ``` (or ```json)
            # fence and the closing ``` fence.
            json_lines = []
            inside = False
            for line in text.split("\n"):
                if line.strip().startswith("```") and not inside:
                    inside = True
                    continue
                elif line.strip() == "```" and inside:
                    break
                elif inside:
                    json_lines.append(line)
            text = "\n".join(json_lines)
        return json.loads(text)

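# A minimal custom-backend sketch (illustrative; `EchoBackend` is hypothetical
# and not part of this module). Only generate() must be implemented;
# generate_structured() is inherited from the default above, which appends
# the schema instruction and parses the JSON reply:
#
#     class EchoBackend(LLMBackend):
#         def generate(self, messages, temperature=0.7,
#                      max_tokens=2048, stop=None):
#             # Echo the last message back, stripping any <think> tags.
#             return self._strip_thinking(messages[-1].content)
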
# ---------------------------------------------------------------------------
# HuggingFace Inference Provider Backend
# ---------------------------------------------------------------------------
class HFInferenceBackend(LLMBackend):
    """
    Uses huggingface_hub InferenceClient for HF Inference Providers.

    Supports: Cerebras, Novita, Fireworks, Together, SambaNova, etc.
    Models: Qwen, Llama, Mistral, DeepSeek — anything on HF Hub.

    Example:
        backend = HFInferenceBackend(
            model_id="Qwen/Qwen3-32B",
            provider="cerebras",
        )
    """

    def __init__(
        self,
        model_id: str = "Qwen/Qwen3-32B",
        provider: str = "auto",
        api_key: str | None = None,
    ):
        from huggingface_hub import InferenceClient

        self.model_id = model_id
        self.provider = provider
        self.client = InferenceClient(
            provider=provider,
            api_key=api_key or os.environ.get("HF_TOKEN"),
        )

    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
        response = self.client.chat_completion(
            model=self.model_id,
            messages=msg_dicts,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop or [],
        )
        return self._strip_thinking(response.choices[0].message.content or "")

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
        response = self.client.chat_completion(
            model=self.model_id,
            messages=msg_dicts,
            temperature=temperature,
            max_tokens=max_tokens,
            response_format={
                "type": "json_schema",
                "json_schema": {"schema": schema},
            },
        )
        return json.loads(response.choices[0].message.content)

# ---------------------------------------------------------------------------
# OpenAI-Compatible Backend (OpenAI, Azure, vLLM, Ollama, LiteLLM)
# ---------------------------------------------------------------------------
class OpenAICompatibleBackend(LLMBackend):
    """
    Works with any OpenAI-compatible API endpoint.

    Examples:
        # OpenAI
        backend = OpenAICompatibleBackend(model="gpt-4o")

        # Local Ollama
        backend = OpenAICompatibleBackend(
            model="llama3.2",
            base_url="http://localhost:11434/v1",
            api_key="ollama",
        )

        # vLLM server
        backend = OpenAICompatibleBackend(
            model="meta-llama/Llama-3.2-3B-Instruct",
            base_url="http://localhost:8000/v1",
            api_key="token-placeholder",
        )

        # HF Inference via OpenAI SDK (for structured output with .parse())
        backend = OpenAICompatibleBackend(
            model="Qwen/Qwen3-32B",
            base_url="https://router.huggingface.co/cerebras/v1",
            api_key=os.environ["HF_TOKEN"],
        )
    """

    def __init__(
        self,
        model: str = "gpt-4o",
        base_url: str | None = None,
        api_key: str | None = None,
        timeout: float = 60.0,
    ):
        from openai import OpenAI

        self.model = model
        self.client = OpenAI(
            base_url=base_url,
            api_key=api_key or os.environ.get("OPENAI_API_KEY"),
            timeout=timeout,
        )

    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
        response = self.client.chat.completions.create(
            model=self.model,
            messages=msg_dicts,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
        )
        return self._strip_thinking(response.choices[0].message.content or "")

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
        response = self.client.chat.completions.create(
            model=self.model,
            messages=msg_dicts,
            temperature=temperature,
            max_tokens=max_tokens,
            response_format={
                "type": "json_schema",
                "json_schema": {"name": "purpose_score", "schema": schema},
            },
        )
        return json.loads(response.choices[0].message.content)

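# Structured-output usage (illustrative; the schema and prompt below are
# hypothetical examples, not defaults used by this module):
#
#     schema = {
#         "type": "object",
#         "properties": {
#             "score": {"type": "number"},
#             "reason": {"type": "string"},
#         },
#         "required": ["score", "reason"],
#     }
#     result = backend.generate_structured(
#         [ChatMessage(role="user", content="Score this plan from 0 to 10.")],
#         schema=schema,
#     )
#     # result is a dict, e.g. {"score": 7.0, "reason": "..."}
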
# ---------------------------------------------------------------------------
# Mock Backend (for testing without API calls)
# ---------------------------------------------------------------------------
class MockLLMBackend(LLMBackend):
    """
    Deterministic mock backend for testing the framework without LLM calls.

    Returns canned responses based on keywords in the prompt, or a default.
    You can register custom response handlers.
    """

    def __init__(self):
        self._handlers: list[tuple[str, str | Callable[[list[ChatMessage]], Any]]] = []
        self._structured_default: dict[str, Any] = {}
        self._call_log: list[dict] = []

    def register_handler(
        self, keyword: str, response: str | Callable[[list[ChatMessage]], Any]
    ) -> "MockLLMBackend":
        """Add a keyword-matched response handler. Checked in order."""
        self._handlers.append((keyword, response))
        return self

    def set_structured_default(self, default: dict[str, Any]) -> "MockLLMBackend":
        """Set the default response for structured generation."""
        self._structured_default = default
        return self

    @property
    def call_log(self) -> list[dict]:
        return self._call_log

    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        full_text = " ".join(m.content for m in messages)
        self._call_log.append({
            "method": "generate",
            "messages": [{"role": m.role, "content": m.content[:200]} for m in messages],
        })
        for keyword, response in self._handlers:
            if keyword.lower() in full_text.lower():
                if callable(response):
                    return response(messages)
                return response
        # Default: acknowledge the last user message
        last_user = next(
            (m.content for m in reversed(messages) if m.role == "user"),
            "no input",
        )
        return f"[MockLLM] Acknowledged: {last_user[:100]}"

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        self._call_log.append({
            "method": "generate_structured",
            "schema_keys": list(schema.get("properties", {}).keys()),
        })
        # Try keyword handlers first — they may return JSON strings or dicts
        full_text = " ".join(m.content for m in messages)
        for keyword, response in self._handlers:
            if keyword.lower() in full_text.lower():
                result = response(messages) if callable(response) else response
                # If the handler returned a string, try to parse it as JSON
                if isinstance(result, str):
                    try:
                        return json.loads(result)
                    except (json.JSONDecodeError, TypeError):
                        pass
                elif isinstance(result, dict):
                    return result
        # Fall back to the structured default
        if self._structured_default:
            return self._structured_default
        # Build a minimal valid response from the schema
        result = {}
        for key, prop in schema.get("properties", {}).items():
            ptype = prop.get("type", "string")
            if ptype == "number":
                result[key] = 5.0
            elif ptype == "integer":
                result[key] = 5
            elif ptype == "boolean":
                result[key] = True
            else:
                result[key] = f"mock_{key}"
        return result

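# Test wiring (illustrative; handler text and scores below are made up):
#
#     backend = (
#         MockLLMBackend()
#         .register_handler("weather", "Sunny, 22C.")
#         .set_structured_default({"score": 7.5, "reasoning": "mock"})
#     )
#     backend.generate([ChatMessage(role="user", content="What's the weather?")])
#     # -> "Sunny, 22C."  (handlers are checked in registration order)
#     backend.generate_structured([], schema={"properties": {}})
#     # -> {"score": 7.5, "reasoning": "mock"}
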
# ---------------------------------------------------------------------------
# Multi-Provider Router
# ---------------------------------------------------------------------------
# Provider → (base_url, env_var_for_key)
_PROVIDER_MAP = {
    "groq": ("https://api.groq.com/openai/v1", "GROQ_API_KEY"),
    "openai": ("https://api.openai.com/v1", "OPENAI_API_KEY"),
    "together": ("https://api.together.xyz/v1", "TOGETHER_API_KEY"),
    "fireworks": ("https://api.fireworks.ai/inference/v1", "FIREWORKS_API_KEY"),
    "deepseek": ("https://api.deepseek.com/v1", "DEEPSEEK_API_KEY"),
    "mistral": ("https://api.mistral.ai/v1", "MISTRAL_API_KEY"),
    "cerebras": ("https://api.cerebras.ai/v1", "CEREBRAS_API_KEY"),
    "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
}

def resolve_backend(spec: str, api_key: str | None = None) -> LLMBackend:
    """
    Resolve a 'provider:model' string into an LLMBackend.

    Supports the providers listed in _PROVIDER_MAP via their OpenAI-compatible
    APIs, plus Ollama for local models and HF for HuggingFace Inference.

    Examples:
        resolve_backend("groq:llama-3.3-70b-versatile")
        resolve_backend("openai:gpt-4o")
        resolve_backend("ollama:qwen3:1.7b")
        resolve_backend("hf:Qwen/Qwen3-32B")
        resolve_backend("together:meta-llama/Llama-3.3-70B-Instruct-Turbo")
        resolve_backend("deepseek:deepseek-chat")

    For local models without a provider prefix:
        resolve_backend("qwen3:1.7b")       # auto-detects Ollama
        resolve_backend("gpt-4o")           # auto-detects OpenAI
        resolve_backend("Qwen/Qwen3-32B")   # auto-detects HF
    """
    if ":" in spec:
        parts = spec.split(":", 1)
        provider = parts[0].lower()
        if provider == "ollama":
            from purpose_agent.slm_backends import OllamaBackend

            return OllamaBackend(model=parts[1])
        if provider == "hf":
            return HFInferenceBackend(model_id=parts[1], api_key=api_key)
        if provider in _PROVIDER_MAP:
            base_url, env_var = _PROVIDER_MAP[provider]
            key = api_key or os.environ.get(env_var, "")
            if not key:
                raise ValueError(
                    f"No API key for {provider}. Set the {env_var} environment "
                    f"variable or pass the api_key= parameter."
                )
            return OpenAICompatibleBackend(
                model=parts[1], base_url=base_url, api_key=key,
            )
        # Not a known provider — might be an Ollama model tag like "qwen3:1.7b"
        from purpose_agent.slm_backends import OllamaBackend

        return OllamaBackend(model=spec)
    # No colon — auto-detect
    if spec.startswith(("gpt-", "o1", "o3")):
        key = api_key or os.environ.get("OPENAI_API_KEY", "")
        return OpenAICompatibleBackend(model=spec, api_key=key)
    if "/" in spec:
        return HFInferenceBackend(model_id=spec, api_key=api_key)
    from purpose_agent.slm_backends import OllamaBackend

    return OllamaBackend(model=spec)
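

# Illustrative addition: a tiny offline smoke test using the mock backend.
# No network access or API keys are required.
if __name__ == "__main__":
    demo = MockLLMBackend().set_structured_default({"score": 8.0})
    print(demo.generate([ChatMessage(role="user", content="ping")]))
    # -> "[MockLLM] Acknowledged: ping"
    print(demo.generate_structured(
        [ChatMessage(role="user", content="score it")],
        schema={"properties": {"score": {"type": "number"}}},
    ))
    # -> {"score": 8.0}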