from __future__ import annotations

import json
import os
from dataclasses import dataclass
from typing import Any, Protocol

import requests
from requests import RequestException

from osint_env.domain.models import LLMConfig


@dataclass(slots=True)
class LLMResponse:
    """Normalized model output: free-form text plus any parsed tool calls."""

    content: str
    tool_calls: list[dict[str, Any]]


class LLMClient(Protocol):
    """Minimal interface every provider adapter must satisfy."""

    def generate(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]]) -> LLMResponse:
        ...


class RuleBasedMockLLM:
    """Deterministic fallback for local testing without model dependencies."""

    def generate(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]]) -> LLMResponse:
        # Find the most recent system message that carries the task question.
        question = ""
        for m in reversed(messages):
            if m.get("role") == "system" and "question" in m.get("content", ""):
                question = m["content"]
                break
        if "alias" in question:
            return LLMResponse(
                content="Need alias lookup.",
                tool_calls=[
                    {"tool_name": "search_posts", "args": {"query": "Update"}},
                    {"tool_name": "get_profile", "args": {"user_id": "user_0"}},
                ],
            )
        return LLMResponse(
            content="Need profile lookup.",
            tool_calls=[{"tool_name": "search_people", "args": {"org": "Apex"}}],
        )


class OllamaLLMClient:
    """Adapter for a local Ollama server's /api/chat endpoint."""

    def __init__(
        self,
        model: str,
        base_url: str = "http://127.0.0.1:11434",
        temperature: float = 0.1,
        timeout_seconds: int = 240,
    ):
        self.model = model
        self.base_url = base_url.rstrip("/")
        self.temperature = float(temperature)
        self.timeout_seconds = int(timeout_seconds)

    @staticmethod
    def _extract_tool_calls(content: str) -> list[dict[str, Any]]:
        """Parse tool calls from the model's free-text reply.

        Expects a JSON object embedded in the text, shaped like
        {"tool_calls": [{"tool_name": "...", "args": {...}}]}.
        """
        text = str(content or "").strip()
        if not text:
            return []
        # Take the widest {...} span so prose around the JSON does not break parsing.
        left = text.find("{")
        right = text.rfind("}")
        if left >= 0 and right > left:
            snippet = text[left : right + 1]
            try:
                parsed = json.loads(snippet)
            except json.JSONDecodeError:
                parsed = None
            if isinstance(parsed, dict) and isinstance(parsed.get("tool_calls"), list):
                out: list[dict[str, Any]] = []
                for item in parsed["tool_calls"]:
                    if isinstance(item, dict) and "tool_name" in item and isinstance(item.get("args", {}), dict):
                        out.append({"tool_name": str(item["tool_name"]), "args": dict(item.get("args", {}))})
                return out
        return []

    def generate(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]]) -> LLMResponse:
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": False,
            "options": {
                "temperature": self.temperature,
            },
        }
        if tools:
            payload["tools"] = tools
        try:
            response = requests.post(
                f"{self.base_url}/api/chat",
                json=payload,
                timeout=self.timeout_seconds,
            )
            response.raise_for_status()
            data = response.json()
            content = str((data.get("message") or {}).get("content", ""))
            tool_calls = self._extract_tool_calls(content)
            return LLMResponse(content=content, tool_calls=tool_calls)
        except (RequestException, ValueError):
            # Keep episode execution resilient when local model calls are transiently slow or unavailable.
            return LLMResponse(content="", tool_calls=[])


class OpenAILLMClient:
    """Thin wrapper around the OpenAI chat.completions API."""

    def __init__(
        self,
        model: str,
        api_key: str,
        base_url: str = "https://api.openai.com/v1",
        temperature: float = 0.1,
        max_tokens: int = 256,
        timeout_seconds: int = 240,
    ):
        # Imported lazily so the openai package stays an optional dependency.
        from openai import OpenAI

        self.model = model
        self.temperature = float(temperature)
        self.max_tokens = int(max_tokens)
        self.client = OpenAI(api_key=api_key, base_url=base_url, timeout=timeout_seconds)

    def generate(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]]) -> LLMResponse:
        kwargs: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        if tools:
            kwargs["tools"] = tools
        try:
            completion = self.client.chat.completions.create(**kwargs)
            message = completion.choices[0].message
            content = message.content if isinstance(message.content, str) else ""
            tool_calls: list[dict[str, Any]] = []
            for tc in message.tool_calls or []:
                try:
                    args = json.loads(tc.function.arguments or "{}")
                except json.JSONDecodeError:
                    args = {}
                tool_calls.append({"tool_name": tc.function.name, "args": args if isinstance(args, dict) else {}})
            return LLMResponse(content=content, tool_calls=tool_calls)
        except Exception:
            # Mirror the Ollama client: never let a provider error abort the episode.
            return LLMResponse(content="", tool_calls=[])


def build_llm_client(config: LLMConfig | None = None) -> LLMClient:
    cfg = config or LLMConfig()
    provider = str(cfg.provider).strip().lower()
    if provider in {"", "mock", "rule", "rule_based"}:
        return RuleBasedMockLLM()
    if provider == "ollama":
        return OllamaLLMClient(
            model=cfg.model,
            base_url=cfg.ollama_base_url,
            temperature=cfg.temperature,
            timeout_seconds=cfg.timeout_seconds,
        )
    if provider == "openai":
        api_key = cfg.openai_api_key or os.getenv(cfg.openai_api_key_env, "")
        if not api_key:
            raise ValueError(
                "OpenAI provider selected but API key is missing. "
                f"Set {cfg.openai_api_key_env} or populate openai_api_key in config."
            )
        return OpenAILLMClient(
            model=cfg.model,
            api_key=api_key,
            base_url=cfg.openai_base_url,
            temperature=cfg.temperature,
            max_tokens=cfg.max_tokens,
            timeout_seconds=cfg.timeout_seconds,
        )
    raise ValueError(f"Unsupported llm provider: {cfg.provider}")
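

# A minimal smoke-test sketch (hypothetical, not part of the module's public
# surface): it exercises the deterministic mock end to end, assuming the
# system-message format that RuleBasedMockLLM.generate matches on above.
if __name__ == "__main__":
    client = RuleBasedMockLLM()
    msgs = [{"role": "system", "content": "question: what alias does user_0 post under?"}]
    resp = client.generate(msgs, tools=[])
    print(resp.content)     # -> "Need alias lookup."
    print(resp.tool_calls)  # -> search_posts followed by get_profile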