Spaces:
Paused
Paused
| from __future__ import annotations | |
| import re | |
| from osint_env.domain.models import Action, ActionType | |
| from osint_env.env.environment import OSINTEnvironment | |
| from osint_env.llm.interface import LLMClient, RuleBasedMockLLM | |
| from osint_env.platforms.tool_schemas import build_lookup_tools | |
| class SingleAgentRunner: | |
| def __init__(self, env: OSINTEnvironment, llm: LLMClient | None = None): | |
| self.env = env | |
| self.llm = llm or RuleBasedMockLLM() | |
| def run_episode(self) -> dict: | |
| obs = self.env.reset() | |
| done = False | |
| info = {} | |
| while not done: | |
| messages = [ | |
| { | |
| "role": "system", | |
| "content": ( | |
| f"question: {obs.task['question']}\n" | |
| f"shared_context_available: {bool(obs.task.get('shared_context_available', False))}\n" | |
| "Use lookup tools to gather evidence before answering." | |
| ), | |
| } | |
| ] | |
| tools = build_lookup_tools() | |
| try: | |
| llm_resp = self.llm.generate(messages, tools) | |
| planned_calls = llm_resp.tool_calls[:2] | |
| except Exception: | |
| planned_calls = [] | |
| if not planned_calls and bool(obs.task.get("shared_context_available", False)): | |
| planned_calls = [ | |
| { | |
| "tool_name": "search_shared_context", | |
| "args": {"query": self._shared_context_query(obs.task["question"]), "k": 5}, | |
| } | |
| ] | |
| for call in planned_calls: | |
| obs, _, done, info = self.env.step(Action(ActionType.CALL_TOOL, call)) | |
| if done: | |
| break | |
| if done: | |
| break | |
| answer_guess = self._heuristic_answer(obs.task["question"]) | |
| obs, _, done, info = self.env.step(Action(ActionType.ANSWER, {"answer": answer_guess})) | |
| return info | |
| def _heuristic_answer(question: str) -> str: | |
| for token in question.replace("?", "").split(): | |
| if token.startswith("alias_") or token.startswith("user_"): | |
| return token | |
| return "unknown" | |
| def _shared_context_query(question: str) -> str: | |
| id_match = re.search(r"\b(?:alias|user|post|thr|thread|org|loc|event)_[A-Za-z0-9_]+\b", question) | |
| if id_match: | |
| return id_match.group(0) | |
| path_match = re.search(r"relation path\s+(.+?),\s*which entity", question, flags=re.IGNORECASE) | |
| if path_match: | |
| first_relation = path_match.group(1).split("->", 1)[0].strip() | |
| if first_relation: | |
| return first_relation | |
| tokens = re.findall(r"[A-Za-z0-9_]+", question) | |
| return tokens[0] if tokens else question | |