from __future__ import annotations

import json
import os
from dataclasses import dataclass
from typing import Any, Protocol

import requests
from requests import RequestException

from osint_env.domain.models import LLMConfig


@dataclass
class LLMResponse:
    content: str
    tool_calls: list[dict[str, Any]]


class LLMClient(Protocol):
    def generate(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]]) -> LLMResponse:
        ...


class RuleBasedMockLLM:
    """Deterministic fallback for local testing without model dependencies."""

    def generate(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]]) -> LLMResponse:
        # Recover the task question from the most recent system message that carries one.
        question = ""
        for m in reversed(messages):
            if m.get("role") == "system" and "question" in m.get("content", ""):
                question = m["content"]
                break
        if "alias" in question:
            return LLMResponse(
                content="Need alias lookup.",
                tool_calls=[
                    {"tool_name": "search_posts", "args": {"query": "Update"}},
                    {"tool_name": "get_profile", "args": {"user_id": "user_0"}},
                ],
            )
        return LLMResponse(
            content="Need profile lookup.",
            tool_calls=[{"tool_name": "search_people", "args": {"org": "Apex"}}],
        )
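

# Illustrative usage sketch (an assumption, not part of the harness: it presumes
# the caller embeds the task question in a system message, which is what
# generate() above scans for):
#
#   mock = RuleBasedMockLLM()
#   resp = mock.generate([{"role": "system", "content": "question: find the alias"}], tools=[])
#   assert [tc["tool_name"] for tc in resp.tool_calls] == ["search_posts", "get_profile"]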


class OllamaLLMClient:
    def __init__(
        self,
        model: str,
        base_url: str = "http://127.0.0.1:11434",
        temperature: float = 0.1,
        timeout_seconds: int = 240,
    ):
        self.model = model
        self.base_url = base_url.rstrip("/")
        self.temperature = float(temperature)
        self.timeout_seconds = int(timeout_seconds)

    @staticmethod
    def _extract_tool_calls(content: str) -> list[dict[str, Any]]:
        text = str(content or "").strip()
        if not text:
            return []
        left = text.find("{")
        right = text.rfind("}")
        if left >= 0 and right > left:
            snippet = text[left : right + 1]
            try:
                parsed = json.loads(snippet)
            except json.JSONDecodeError:
                parsed = None
            if isinstance(parsed, dict) and isinstance(parsed.get("tool_calls"), list):
                out: list[dict[str, Any]] = []
                for item in parsed["tool_calls"]:
                    if isinstance(item, dict) and "tool_name" in item and isinstance(item.get("args", {}), dict):
                        out.append({"tool_name": str(item["tool_name"]), "args": dict(item.get("args", {}))})
                return out
        return []
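
    # _extract_tool_calls expects the model to embed a JSON envelope of the form
    #   {"tool_calls": [{"tool_name": "search_posts", "args": {"query": "..."}}]}
    # anywhere in its reply. Text outside the outermost braces is ignored, and
    # malformed JSON or missing keys degrade to an empty list instead of raising.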

    def generate(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]]) -> LLMResponse:
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": False,
            "options": {
                "temperature": self.temperature,
            },
        }
        if tools:
            payload["tools"] = tools
        try:
            response = requests.post(
                f"{self.base_url}/api/chat",
                json=payload,
                timeout=self.timeout_seconds,
            )
            response.raise_for_status()
            data = response.json()
            content = str((data.get("message") or {}).get("content", ""))
            tool_calls = self._extract_tool_calls(content)
            return LLMResponse(content=content, tool_calls=tool_calls)
        except (RequestException, ValueError):
            # Keep episode execution resilient when local model calls are transiently slow/unavailable.
            return LLMResponse(content="", tool_calls=[])
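

# For reference, a successful non-streaming /api/chat reply from Ollama is shaped
# roughly like (shape assumed from Ollama's chat API; only message.content is read):
#   {"model": "...", "message": {"role": "assistant", "content": "..."}, "done": true}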


class OpenAILLMClient:
    def __init__(
        self,
        model: str,
        api_key: str,
        base_url: str = "https://api.openai.com/v1",
        temperature: float = 0.1,
        max_tokens: int = 256,
        timeout_seconds: int = 240,
    ):
        # Imported lazily so the openai package is only required when this provider is selected.
        from openai import OpenAI

        self.model = model
        self.temperature = float(temperature)
        self.max_tokens = int(max_tokens)
        self.client = OpenAI(api_key=api_key, base_url=base_url, timeout=timeout_seconds)

    def generate(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]]) -> LLMResponse:
        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        if tools:
            kwargs["tools"] = tools
        try:
            completion = self.client.chat.completions.create(**kwargs)
            message = completion.choices[0].message
            content = message.content if isinstance(message.content, str) else ""
            tool_calls: list[dict[str, Any]] = []
            for tc in message.tool_calls or []:
                try:
                    args = json.loads(tc.function.arguments or "{}")
                except json.JSONDecodeError:
                    args = {}
                tool_calls.append({"tool_name": tc.function.name, "args": args if isinstance(args, dict) else {}})
            return LLMResponse(content=content, tool_calls=tool_calls)
        except Exception:
            # Mirror the Ollama client: degrade to an empty response rather than aborting the episode.
            return LLMResponse(content="", tool_calls=[])
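

# Both backends normalize tool calls to the same [{"tool_name": str, "args": dict}]
# records, so downstream episode code can stay provider-agnostic.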


def build_llm_client(config: LLMConfig | None = None) -> LLMClient:
    cfg = config or LLMConfig()
    provider = str(cfg.provider).strip().lower()
    if provider in {"", "mock", "rule", "rule_based"}:
        return RuleBasedMockLLM()
    if provider == "ollama":
        return OllamaLLMClient(
            model=cfg.model,
            base_url=cfg.ollama_base_url,
            temperature=cfg.temperature,
            timeout_seconds=cfg.timeout_seconds,
        )
    if provider == "openai":
        api_key = cfg.openai_api_key or os.getenv(cfg.openai_api_key_env, "")
        if not api_key:
            raise ValueError(
                "OpenAI provider selected but API key is missing. "
                f"Set {cfg.openai_api_key_env} or populate openai_api_key in config."
            )
        return OpenAILLMClient(
            model=cfg.model,
            api_key=api_key,
            base_url=cfg.openai_base_url,
            temperature=cfg.temperature,
            max_tokens=cfg.max_tokens,
            timeout_seconds=cfg.timeout_seconds,
        )
    raise ValueError(f"Unsupported llm provider: {cfg.provider}")
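

# Minimal usage sketch (hedged: assumes LLMConfig accepts the provider/model
# fields read above as constructor arguments; the model name is a placeholder.
# Any provider outside {"", "mock", "rule", "rule_based", "ollama", "openai"}
# raises ValueError):
#
#   client = build_llm_client(LLMConfig(provider="ollama", model="llama3.1"))
#   resp = client.generate(
#       messages=[{"role": "system", "content": "question: who runs Apex?"}],
#       tools=[],
#   )
#   print(resp.content, resp.tool_calls)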