"""
LLM Backend — Swappable inference layer.

Supports: HuggingFace Inference Providers, OpenAI, Anthropic, local models,
or any custom backend. Swap by changing one constructor call.

Design: Abstract base class with structured output support.
Inspired by smolagents Model interface + HF Inference Providers API.
"""

from __future__ import annotations

import json
import logging
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Message types (OpenAI-compatible chat format)
# ---------------------------------------------------------------------------

@dataclass
class ChatMessage:
    role: str  # "system", "user", "assistant"
    content: str
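
# Illustrative only (not part of the module's API): a typical conversation in
# this format looks like
#   [ChatMessage(role="system", content="You are a terse assistant."),
#    ChatMessage(role="user", content="Summarize the design in one line.")]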


# ---------------------------------------------------------------------------
# Abstract LLM Backend
# ---------------------------------------------------------------------------

class LLMBackend(ABC):
    """
    Abstract LLM backend. All modules call this — swap the implementation
    to change the underlying model without touching any other code.
    
    Subclasses must implement `generate()` which takes messages and returns
    a string. Optionally implement `generate_structured()` for JSON-schema
    constrained generation (used by the Purpose Function for reliable scoring).
    """

    @staticmethod
    def _strip_thinking(text: str) -> str:
        """
        Strip <think>...</think> tags from model output.
        
        Many reasoning models (Qwen3, DeepSeek-R1, etc.) wrap their
        chain-of-thought in <think> tags. We keep only the final answer.
        """
        # Remove <think>...</think> blocks (non-greedy; DOTALL spans newlines)
        cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
        # Also handle unclosed <think> tags (model cut off mid-thought)
        cleaned = re.sub(r'<think>.*$', '', cleaned, flags=re.DOTALL)
        return cleaned.strip()
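
    # Illustrative behavior:
    #   _strip_thinking("<think>chain of thought</think>42")  -> "42"
    #   _strip_thinking("<think>cut off mid-thought")         -> ""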

    @abstractmethod
    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        """Generate a text completion from chat messages."""
        ...

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        """
        Generate with JSON schema constraint.
        
        Default implementation: append schema instruction to last message
        and parse JSON from response. Override for native structured output.
        """
        schema_instruction = (
            f"\n\nYou MUST respond with valid JSON matching this schema:\n"
            f"```json\n{json.dumps(schema, indent=2)}\n```\n"
            f"Respond ONLY with the JSON object, no other text."
        )
        augmented = list(messages)
        last = augmented[-1]
        augmented[-1] = ChatMessage(
            role=last.role, content=last.content + schema_instruction
        )
        raw = self.generate(augmented, temperature=temperature, max_tokens=max_tokens)

        # Extract JSON from response (handle markdown code blocks)
        text = raw.strip()
        if text.startswith("```"):
            lines = text.split("\n")
            # Remove first and last ``` lines
            json_lines = []
            inside = False
            for line in lines:
                if line.strip().startswith("```") and not inside:
                    inside = True
                    continue
                elif line.strip() == "```" and inside:
                    break
                elif inside:
                    json_lines.append(line)
            text = "\n".join(json_lines)

        return json.loads(text)
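
    # Illustrative: a raw reply of the form
    #   ```json
    #   {"score": 7}
    #   ```
    # is unwrapped to '{"score": 7}' before json.loads().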


# ---------------------------------------------------------------------------
# HuggingFace Inference Provider Backend
# ---------------------------------------------------------------------------

class HFInferenceBackend(LLMBackend):
    """
    Uses huggingface_hub InferenceClient for HF Inference Providers.
    
    Supports: Cerebras, Novita, Fireworks, Together, SambaNova, etc.
    Models: Qwen, Llama, Mistral, DeepSeek — anything on HF Hub.
    
    Example:
        backend = HFInferenceBackend(
            model_id="Qwen/Qwen3-32B",
            provider="cerebras",
        )
    """

    def __init__(
        self,
        model_id: str = "Qwen/Qwen3-32B",
        provider: str = "auto",
        api_key: str | None = None,
    ):
        from huggingface_hub import InferenceClient

        self.model_id = model_id
        self.provider = provider
        self.client = InferenceClient(
            provider=provider,
            api_key=api_key or os.environ.get("HF_TOKEN"),
        )

    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
        response = self.client.chat_completion(
            model=self.model_id,
            messages=msg_dicts,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop or [],
        )
        return self._strip_thinking(response.choices[0].message.content or "")

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
        response = self.client.chat_completion(
            model=self.model_id,
            messages=msg_dicts,
            temperature=temperature,
            max_tokens=max_tokens,
            response_format={
                "type": "json_schema",
                "json_schema": {"schema": schema},
            },
        )
        return json.loads(response.choices[0].message.content)
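
# Illustrative call (assumes the chosen provider supports schema-constrained
# decoding via response_format; model and provider are examples):
#   backend = HFInferenceBackend(model_id="Qwen/Qwen3-32B", provider="cerebras")
#   backend.generate_structured(
#       [ChatMessage(role="user", content="Rate this 1-10.")],
#       schema={"type": "object", "properties": {"score": {"type": "number"}}},
#   )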


# ---------------------------------------------------------------------------
# OpenAI-Compatible Backend (OpenAI, Azure, vLLM, Ollama, LiteLLM)
# ---------------------------------------------------------------------------

class OpenAICompatibleBackend(LLMBackend):
    """
    Works with any OpenAI-compatible API endpoint.
    
    Examples:
        # OpenAI
        backend = OpenAICompatibleBackend(model="gpt-4o")
        
        # Local Ollama
        backend = OpenAICompatibleBackend(
            model="llama3.2",
            base_url="http://localhost:11434/v1",
            api_key="ollama",
        )
        
        # vLLM server
        backend = OpenAICompatibleBackend(
            model="meta-llama/Llama-3.2-3B-Instruct",
            base_url="http://localhost:8000/v1",
            api_key="token-placeholder",
        )
        
        # HF Inference via OpenAI SDK (for structured output with .parse())
        backend = OpenAICompatibleBackend(
            model="Qwen/Qwen3-32B",
            base_url="https://router.huggingface.co/cerebras/v1",
            api_key=os.environ["HF_TOKEN"],
        )
    """

    def __init__(
        self,
        model: str = "gpt-4o",
        base_url: str | None = None,
        api_key: str | None = None,
        timeout: float = 60.0,
    ):
        from openai import OpenAI

        self.model = model
        self.client = OpenAI(
            base_url=base_url,
            api_key=api_key or os.environ.get("OPENAI_API_KEY"),
            timeout=timeout,
        )

    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
        response = self.client.chat.completions.create(
            model=self.model,
            messages=msg_dicts,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
        )
        return self._strip_thinking(response.choices[0].message.content or "")

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        msg_dicts = [{"role": m.role, "content": m.content} for m in messages]
        response = self.client.chat.completions.create(
            model=self.model,
            messages=msg_dicts,
            temperature=temperature,
            max_tokens=max_tokens,
            response_format={
                "type": "json_schema",
                "json_schema": {"name": "purpose_score", "schema": schema},
            },
        )
        return json.loads(response.choices[0].message.content)


# ---------------------------------------------------------------------------
# Mock Backend (for testing without API calls)
# ---------------------------------------------------------------------------

class MockLLMBackend(LLMBackend):
    """
    Deterministic mock backend for testing the framework without LLM calls.
    
    Returns canned responses based on keywords in the prompt, or a default.
    You can register custom response handlers.
    """

    def __init__(self):
        self._handlers: list[tuple[str, str | Callable[[list[ChatMessage]], Any]]] = []
        self._structured_default: dict[str, Any] = {}
        self._call_log: list[dict] = []

    def register_handler(
        self, keyword: str, response: str | Callable[[list[ChatMessage]], Any]
    ) -> "MockLLMBackend":
        """Add a keyword-matched response handler. Checked in order."""
        self._handlers.append((keyword, response))
        return self

    def set_structured_default(self, default: dict[str, Any]) -> "MockLLMBackend":
        """Set the default response for structured generation."""
        self._structured_default = default
        return self

    @property
    def call_log(self) -> list[dict]:
        return self._call_log

    def generate(
        self,
        messages: list[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        stop: list[str] | None = None,
    ) -> str:
        full_text = " ".join(m.content for m in messages)
        self._call_log.append({
            "method": "generate",
            "messages": [{"role": m.role, "content": m.content[:200]} for m in messages],
        })

        for keyword, response in self._handlers:
            if keyword.lower() in full_text.lower():
                if callable(response):
                    return response(messages)
                return response

        # Default: echo the last user message with a generic response
        last_user = next(
            (m.content for m in reversed(messages) if m.role == "user"),
            "no input",
        )
        return f"[MockLLM] Acknowledged: {last_user[:100]}"

    def generate_structured(
        self,
        messages: list[ChatMessage],
        schema: dict[str, Any],
        temperature: float = 0.3,
        max_tokens: int = 1024,
    ) -> dict[str, Any]:
        self._call_log.append({
            "method": "generate_structured",
            "schema_keys": list(schema.get("properties", {}).keys()),
        })
        # Try keyword handlers first — they may return JSON strings or dicts
        full_text = " ".join(m.content for m in messages)
        for keyword, response in self._handlers:
            if keyword.lower() in full_text.lower():
                if callable(response):
                    result = response(messages)
                else:
                    result = response
                # If handler returned a string, try to parse as JSON
                if isinstance(result, str):
                    try:
                        return json.loads(result)
                    except (json.JSONDecodeError, TypeError):
                        pass
                elif isinstance(result, dict):
                    return result

        # Fall back to structured default
        if self._structured_default:
            return self._structured_default
        # Build a minimal valid response from the schema
        props = schema.get("properties", {})
        result = {}
        for key, prop in props.items():
            ptype = prop.get("type", "string")
            if ptype == "number":
                result[key] = 5.0
            elif ptype == "integer":
                result[key] = 5
            elif ptype == "boolean":
                result[key] = True
            else:
                result[key] = f"mock_{key}"
        return result
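
# A minimal usage sketch (illustrative; the handler strings and canned values
# below are made up for the example):
#   mock = (
#       MockLLMBackend()
#       .register_handler("weather", "Sunny, 22C.")
#       .set_structured_default({"score": 7.5, "reason": "mock"})
#   )
#   mock.generate([ChatMessage(role="user", content="What's the weather?")])
#   # -> "Sunny, 22C."
#   mock.generate_structured([ChatMessage(role="user", content="Score it.")], schema={})
#   # -> {"score": 7.5, "reason": "mock"}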


# ---------------------------------------------------------------------------
# Multi-Provider Router
# ---------------------------------------------------------------------------

# Provider → (base_url, env_var_for_key)
_PROVIDER_MAP = {
    "groq":      ("https://api.groq.com/openai/v1", "GROQ_API_KEY"),
    "openai":    ("https://api.openai.com/v1", "OPENAI_API_KEY"),
    "together":  ("https://api.together.xyz/v1", "TOGETHER_API_KEY"),
    "fireworks": ("https://api.fireworks.ai/inference/v1", "FIREWORKS_API_KEY"),
    "deepseek":  ("https://api.deepseek.com/v1", "DEEPSEEK_API_KEY"),
    "mistral":   ("https://api.mistral.ai/v1", "MISTRAL_API_KEY"),
    "cerebras":  ("https://api.cerebras.ai/v1", "CEREBRAS_API_KEY"),
    "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
}
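
# Extending the router with another OpenAI-compatible provider is a one-line
# change (hypothetical entry shown):
#   _PROVIDER_MAP["myhost"] = ("https://api.myhost.example/v1", "MYHOST_API_KEY")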


def resolve_backend(spec: str, api_key: str | None = None) -> LLMBackend:
    """
    Resolve a 'provider:model' string into an LLMBackend.

    Supports every major inference provider via OpenAI-compatible APIs,
    plus Ollama for local models and HF for HuggingFace Inference.

    Examples:
        resolve_backend("groq:llama-3.3-70b-versatile")
        resolve_backend("openai:gpt-4o")
        resolve_backend("ollama:qwen3:1.7b")
        resolve_backend("hf:Qwen/Qwen3-32B")
        resolve_backend("together:meta-llama/Llama-3.3-70B-Instruct-Turbo")
        resolve_backend("deepseek:deepseek-chat")

    For local models without a provider prefix:
        resolve_backend("qwen3:1.7b")        # auto-detects Ollama
        resolve_backend("gpt-4o")             # auto-detects OpenAI
        resolve_backend("Qwen/Qwen3-32B")    # auto-detects HF
    """
    if ":" in spec:
        parts = spec.split(":", 1)
        provider = parts[0].lower()

        if provider == "ollama":
            from purpose_agent.slm_backends import OllamaBackend
            return OllamaBackend(model=parts[1])

        if provider == "hf":
            return HFInferenceBackend(model_id=parts[1], api_key=api_key)

        if provider in _PROVIDER_MAP:
            base_url, env_var = _PROVIDER_MAP[provider]
            key = api_key or os.environ.get(env_var, "")
            if not key:
                raise ValueError(
                    f"No API key for {provider}. Set {env_var} environment variable "
                    f"or pass api_key= parameter."
                )
            return OpenAICompatibleBackend(
                model=parts[1], base_url=base_url, api_key=key,
            )

        # Not a known provider — might be Ollama model like "qwen3:1.7b"
        from purpose_agent.slm_backends import OllamaBackend
        return OllamaBackend(model=spec)

    # No colon — auto-detect
    if spec.startswith("gpt-") or spec.startswith("o1") or spec.startswith("o3"):
        key = api_key or os.environ.get("OPENAI_API_KEY", "")
        return OpenAICompatibleBackend(model=spec, api_key=key)

    if "/" in spec:
        return HFInferenceBackend(model_id=spec, api_key=api_key)

    from purpose_agent.slm_backends import OllamaBackend
    return OllamaBackend(model=spec)
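

if __name__ == "__main__":
    # Offline smoke test using the mock backend; no API keys required.
    # (Illustrative sketch: the replies and scores below are canned values.)
    backend: LLMBackend = MockLLMBackend().set_structured_default(
        {"score": 8.0, "rationale": "mock rationale"}
    )
    print(backend.generate([ChatMessage(role="user", content="ping")]))
    # -> "[MockLLM] Acknowledged: ping"

    print(backend.generate_structured(
        [ChatMessage(role="user", content="Score this plan.")],
        schema={"type": "object", "properties": {"score": {"type": "number"}}},
    ))
    # -> {"score": 8.0, "rationale": "mock rationale"}

    # Swapping in a real provider is one constructor call, e.g. (requires keys):
    #   backend = resolve_backend("openai:gpt-4o")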