Rohan03 committed
Commit 2cc20b5 · verified · 1 Parent(s): ec4fcc6

Track 1: purpose_agent/llm_backend.py

Files changed (1):
  1. purpose_agent/llm_backend.py +92 -2
purpose_agent/llm_backend.py CHANGED
@@ -44,6 +44,21 @@ class LLMBackend(ABC):
     constrained generation (used by the Purpose Function for reliable scoring).
     """
 
+    @staticmethod
+    def _strip_thinking(text: str) -> str:
+        """
+        Strip <think>...</think> tags from model output.
+
+        Many reasoning models (Qwen3, DeepSeek-R1, etc.) wrap their
+        chain-of-thought in <think> tags. We keep only the final answer.
+        """
+        import re
+        # Remove complete <think>...</think> blocks (non-greedy, multiline)
+        cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
+        # Also handle an unclosed <think> tag (model cut off mid-thought)
+        cleaned = re.sub(r'<think>.*$', '', cleaned, flags=re.DOTALL)
+        return cleaned.strip()
+
     @abstractmethod
     def generate(
         self,
@@ -148,7 +163,7 @@ class HFInferenceBackend(LLMBackend):
             max_tokens=max_tokens,
             stop=stop or [],
         )
-        return response.choices[0].message.content
+        return self._strip_thinking(response.choices[0].message.content or "")
 
     def generate_structured(
         self,
@@ -234,7 +249,7 @@ class OpenAICompatibleBackend(LLMBackend):
             max_tokens=max_tokens,
             stop=stop,
         )
-        return response.choices[0].message.content
+        return self._strip_thinking(response.choices[0].message.content or "")
 
     def generate_structured(
         self,
@@ -361,3 +376,78 @@ class MockLLMBackend(LLMBackend):
         else:
             result[key] = f"mock_{key}"
         return result
+
+
+# ---------------------------------------------------------------------------
+# Multi-Provider Router
+# ---------------------------------------------------------------------------
+
+# Provider → (base_url, env_var_for_key)
+_PROVIDER_MAP = {
+    "groq": ("https://api.groq.com/openai/v1", "GROQ_API_KEY"),
+    "openai": ("https://api.openai.com/v1", "OPENAI_API_KEY"),
+    "together": ("https://api.together.xyz/v1", "TOGETHER_API_KEY"),
+    "fireworks": ("https://api.fireworks.ai/inference/v1", "FIREWORKS_API_KEY"),
+    "deepseek": ("https://api.deepseek.com/v1", "DEEPSEEK_API_KEY"),
+    "mistral": ("https://api.mistral.ai/v1", "MISTRAL_API_KEY"),
+    "cerebras": ("https://api.cerebras.ai/v1", "CEREBRAS_API_KEY"),
+}
+
+
+def resolve_backend(spec: str, api_key: str | None = None) -> LLMBackend:
+    """
+    Resolve a 'provider:model' string into an LLMBackend.
+
+    Supports the major hosted providers via their OpenAI-compatible APIs,
+    plus Ollama for local models and HF for HuggingFace Inference.
+
+    Examples:
+        resolve_backend("groq:llama-3.3-70b-versatile")
+        resolve_backend("openai:gpt-4o")
+        resolve_backend("ollama:qwen3:1.7b")
+        resolve_backend("hf:Qwen/Qwen3-32B")
+        resolve_backend("together:meta-llama/Llama-3.3-70B-Instruct-Turbo")
+        resolve_backend("deepseek:deepseek-chat")
+
+    For model names without a provider prefix:
+        resolve_backend("qwen3:1.7b")      # auto-detects Ollama
+        resolve_backend("gpt-4o")          # auto-detects OpenAI
+        resolve_backend("Qwen/Qwen3-32B")  # auto-detects HF
+    """
+    if ":" in spec:
+        parts = spec.split(":", 1)
+        provider = parts[0].lower()
+
+        if provider == "ollama":
+            from purpose_agent.slm_backends import OllamaBackend
+            return OllamaBackend(model=parts[1])
+
+        if provider == "hf":
+            return HFInferenceBackend(model_id=parts[1], api_key=api_key)
+
+        if provider in _PROVIDER_MAP:
+            base_url, env_var = _PROVIDER_MAP[provider]
+            key = api_key or os.environ.get(env_var, "")
+            if not key:
+                raise ValueError(
+                    f"No API key for {provider}. Set the {env_var} environment "
+                    f"variable or pass the api_key= parameter."
+                )
+            return OpenAICompatibleBackend(
+                model=parts[1], base_url=base_url, api_key=key,
+            )
+
+        # Not a known provider — might be an Ollama model like "qwen3:1.7b"
+        from purpose_agent.slm_backends import OllamaBackend
+        return OllamaBackend(model=spec)
+
+    # No colon — auto-detect from the model name
+    if spec.startswith("gpt-") or spec.startswith("o1") or spec.startswith("o3"):
+        key = api_key or os.environ.get("OPENAI_API_KEY", "")
+        return OpenAICompatibleBackend(model=spec, api_key=key)
+
+    if "/" in spec:
+        return HFInferenceBackend(model_id=spec, api_key=api_key)
+
+    from purpose_agent.slm_backends import OllamaBackend
+    return OllamaBackend(model=spec)
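
For reference, here is how the new _strip_thinking helper behaves on typical reasoning-model output. The standalone function below repeats the two regexes from the diff so the sketch runs on its own; the sample strings are invented for illustration:

import re

def strip_thinking(text: str) -> str:
    # Same two-step cleanup as LLMBackend._strip_thinking above:
    # first drop complete <think>...</think> blocks, then any
    # unclosed trailing <think> (model cut off mid-thought).
    cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    cleaned = re.sub(r'<think>.*$', '', cleaned, flags=re.DOTALL)
    return cleaned.strip()

print(strip_thinking("<think>\nScoring... 7 fits.\n</think>\nScore: 7"))
# -> Score: 7
print(strip_thinking("<think>truncated mid-thou"))
# -> (empty string)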
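
And a quick usage sketch for the router, following the routing rules in the diff. The import path is inferred from the file's location, the key value is a placeholder, and the Groq and Ollama lines assume the corresponding API key or local daemon is available:

import os
from purpose_agent.llm_backend import resolve_backend

os.environ.setdefault("GROQ_API_KEY", "<your-key>")  # placeholder; or pass api_key=

hosted = resolve_backend("groq:llama-3.3-70b-versatile")  # known prefix -> _PROVIDER_MAP
local = resolve_backend("qwen3:1.7b")    # unknown prefix falls through to Ollama
hf = resolve_backend("Qwen/Qwen3-32B")   # "/" with no colon -> HFInferenceBackend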