# purpose_agent/llm_backend.py
"""
LLM Backend — Swappable inference layer.
Supports: HuggingFace Inference Providers, OpenAI, Anthropic, local models,
or any custom backend. Swap by changing one constructor call.
Design: Abstract base class with structured output support.
Inspired by smolagents Model interface + HF Inference Providers API.
"""
from __future__ import annotations

import json
import logging
import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
# Module-level logger; handlers/levels are configured by the application.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Message types (OpenAI-compatible chat format)
# ---------------------------------------------------------------------------
@dataclass
class ChatMessage:
    """One turn of an OpenAI-compatible chat conversation."""

    role: str  # "system", "user", or "assistant"
    content: str  # message text
# ---------------------------------------------------------------------------
# Abstract LLM Backend
# ---------------------------------------------------------------------------
class LLMBackend(ABC):
    """
    Abstract LLM backend. All modules call this — swap the implementation
    to change the underlying model without touching any other code.

    Subclasses must implement `generate()` which takes messages and returns
    a string. Optionally implement `generate_structured()` for JSON-schema
    constrained generation (used by the Purpose Function for reliable scoring).
    """

    @abstractmethod
    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        """Generate a text completion from chat messages."""
        ...

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        """
        Generate with JSON schema constraint.

        Default implementation: append a schema instruction to the last
        message, call `generate()`, then parse JSON out of the response.
        Override for backends with native structured-output support.

        Raises:
            ValueError: if `messages` is empty.
            json.JSONDecodeError: if no parseable JSON is found in the reply.
        """
        if not messages:
            # Explicit error instead of the opaque IndexError a bare [-1] raises.
            raise ValueError("generate_structured() requires at least one message")
        schema_instruction = (
            f"\n\nYou MUST respond with valid JSON matching this schema:\n"
            f"```json\n{json.dumps(schema, indent=2)}\n```\n"
            f"Respond ONLY with the JSON object, no other text."
        )
        augmented = list(messages)
        last = augmented[-1]
        # Rebuild via type(last) so any message class with a (role, content)
        # constructor works, not only ChatMessage.
        augmented[-1] = type(last)(
            role=last.role, content=last.content + schema_instruction
        )
        raw = self.generate(augmented, temperature=temperature, max_tokens=max_tokens)
        return json.loads(self._extract_json(raw))

    @staticmethod
    def _extract_json(raw: str) -> str:
        """
        Best-effort extraction of a JSON payload from an LLM reply.

        Handles markdown code fences (```json ... ```) and, as a fallback,
        prose wrapped around a single JSON object ("Sure! {...} Enjoy.").
        """
        text = raw.strip()
        if text.startswith("```"):
            json_lines: list[str] = []
            inside = False
            for line in text.split("\n"):
                stripped = line.strip()
                if stripped.startswith("```") and not inside:
                    inside = True  # opening fence (``` or ```json)
                    continue
                if stripped == "```" and inside:
                    break  # closing fence — ignore anything after it
                if inside:
                    json_lines.append(line)
            text = "\n".join(json_lines)
        # Models sometimes wrap the object in prose without fences. If the
        # text doesn't already start with a JSON container, fall back to the
        # outermost {...} span when one exists.
        if text and text[0] not in "{[":
            start = text.find("{")
            end = text.rfind("}")
            if start != -1 and end > start:
                text = text[start : end + 1]
        return text
# ---------------------------------------------------------------------------
# HuggingFace Inference Provider Backend
# ---------------------------------------------------------------------------
class HFInferenceBackend(LLMBackend):
    """
    Backend built on huggingface_hub's InferenceClient (HF Inference Providers).

    Works with any provider routed through the Hub (Cerebras, Novita,
    Fireworks, Together, SambaNova, ...) and any hosted chat model on the
    Hub (Qwen, Llama, Mistral, DeepSeek, ...).

    Example:
        backend = HFInferenceBackend(
            model_id="Qwen/Qwen3-32B",
            provider="cerebras",
        )
    """

    def __init__(
        self,
        model_id: str = "Qwen/Qwen3-32B",
        provider: str = "auto",
        api_key: str | None = None,
    ):
        # Imported lazily so the module loads even without huggingface_hub.
        from huggingface_hub import InferenceClient

        self.model_id = model_id
        self.provider = provider
        token = api_key or os.environ.get("HF_TOKEN")
        self.client = InferenceClient(provider=provider, api_key=token)

    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        """Run a plain chat completion and return the assistant's reply text."""
        payload = [{"role": msg.role, "content": msg.content} for msg in messages]
        result = self.client.chat_completion(
            model=self.model_id,
            messages=payload,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop or [],
        )
        first_choice = result.choices[0]
        return first_choice.message.content

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        """Run a schema-constrained chat completion and return the parsed dict."""
        payload = [{"role": msg.role, "content": msg.content} for msg in messages]
        result = self.client.chat_completion(
            model=self.model_id,
            messages=payload,
            temperature=temperature,
            max_tokens=max_tokens,
            response_format={
                "type": "json_schema",
                "json_schema": {"schema": schema},
            },
        )
        return json.loads(result.choices[0].message.content)
# ---------------------------------------------------------------------------
# OpenAI-Compatible Backend (OpenAI, Azure, vLLM, Ollama, LiteLLM)
# ---------------------------------------------------------------------------
class OpenAICompatibleBackend(LLMBackend):
    """
    Backend for any OpenAI-compatible API endpoint.

    Examples:
        # OpenAI
        backend = OpenAICompatibleBackend(model="gpt-4o")

        # Local Ollama
        backend = OpenAICompatibleBackend(
            model="llama3.2",
            base_url="http://localhost:11434/v1",
            api_key="ollama",
        )

        # vLLM server
        backend = OpenAICompatibleBackend(
            model="meta-llama/Llama-3.2-3B-Instruct",
            base_url="http://localhost:8000/v1",
            api_key="token-placeholder",
        )

        # HF Inference via OpenAI SDK (for structured output with .parse())
        backend = OpenAICompatibleBackend(
            model="Qwen/Qwen3-32B",
            base_url="https://router.huggingface.co/cerebras/v1",
            api_key=os.environ["HF_TOKEN"],
        )
    """

    def __init__(
        self,
        model: str = "gpt-4o",
        base_url: str | None = None,
        api_key: str | None = None,
    ):
        # Lazy import: the openai package is only needed when this backend is used.
        from openai import OpenAI

        self.model = model
        key = api_key or os.environ.get("OPENAI_API_KEY")
        self.client = OpenAI(base_url=base_url, api_key=key)

    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        """Plain chat completion; returns the assistant's reply text."""
        convo = [{"role": m.role, "content": m.content} for m in messages]
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=convo,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
        )
        return completion.choices[0].message.content

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        """Schema-constrained chat completion; returns the parsed JSON dict."""
        convo = [{"role": m.role, "content": m.content} for m in messages]
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=convo,
            temperature=temperature,
            max_tokens=max_tokens,
            response_format={
                "type": "json_schema",
                "json_schema": {"name": "purpose_score", "schema": schema},
            },
        )
        return json.loads(completion.choices[0].message.content)
# ---------------------------------------------------------------------------
# Mock Backend (for testing without API calls)
# ---------------------------------------------------------------------------
class MockLLMBackend(LLMBackend):
    """
    Deterministic mock backend for testing the framework without LLM calls.

    Returns canned responses based on keywords in the prompt, or a default.
    Handlers are checked in registration order and may be plain strings,
    JSON strings, dicts (structured path), or callables taking the message
    list. All calls are recorded in `call_log` for test assertions.
    """

    def __init__(self):
        # (keyword, response) pairs, checked in registration order.
        # NOTE: annotated with Callable, not the builtin `callable` — the
        # builtin is a function, not a type, and breaks type checking.
        self._handlers: list[tuple[str, str | Callable[..., Any]]] = []
        # Fallback payload for generate_structured() when no handler matches.
        self._structured_default: dict[str, Any] = {}
        # Record of every call (method name + truncated arguments).
        self._call_log: list[dict] = []

    def register_handler(
        self, keyword: str, response: str | Callable[..., Any]
    ) -> "MockLLMBackend":
        """Add a keyword-matched response handler. Checked in order. Fluent."""
        self._handlers.append((keyword, response))
        return self

    def set_structured_default(self, default: dict[str, Any]) -> "MockLLMBackend":
        """Set the default response for structured generation. Fluent."""
        self._structured_default = default
        return self

    @property
    def call_log(self) -> list[dict]:
        """The recorded calls, in order."""
        return self._call_log

    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        """Return the first matching handler's response, else an echo default."""
        full_text = " ".join(m.content for m in messages)
        self._call_log.append({
            "method": "generate",
            # Truncate content to keep the log small in long conversations.
            "messages": [{"role": m.role, "content": m.content[:200]} for m in messages],
        })
        for keyword, response in self._handlers:
            if keyword.lower() in full_text.lower():
                if callable(response):
                    return response(messages)
                return response
        # Default: echo the last user message with a generic response
        last_user = next(
            (m.content for m in reversed(messages) if m.role == "user"),
            "no input",
        )
        return f"[MockLLM] Acknowledged: {last_user[:100]}"

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        """
        Return the first handler result parseable as a dict, then the
        registered structured default, then a minimal schema-derived stub.
        """
        self._call_log.append({
            "method": "generate_structured",
            "schema_keys": list(schema.get("properties", {}).keys()),
        })
        # Try keyword handlers first — they may return JSON strings or dicts
        full_text = " ".join(m.content for m in messages)
        for keyword, response in self._handlers:
            if keyword.lower() in full_text.lower():
                if callable(response):
                    result = response(messages)
                else:
                    result = response
                # If handler returned a string, try to parse as JSON;
                # unparseable strings fall through to the next handler.
                if isinstance(result, str):
                    try:
                        return json.loads(result)
                    except (json.JSONDecodeError, TypeError):
                        pass
                elif isinstance(result, dict):
                    return result
        # Fall back to structured default
        if self._structured_default:
            return self._structured_default
        # Build a minimal valid response from the schema
        props = schema.get("properties", {})
        result = {}
        for key, prop in props.items():
            ptype = prop.get("type", "string")
            if ptype == "number":
                result[key] = 5.0
            elif ptype == "integer":
                result[key] = 5
            elif ptype == "boolean":
                result[key] = True
            else:
                result[key] = f"mock_{key}"
        return result