OSINT / src /osint_env /agents /swarm_agent.py
siddeshwar-kagatikar
Sync current main to Hugging Face Space
fe1f842
from __future__ import annotations
import re
from typing import Any
from osint_env.domain.models import Action, ActionType
from osint_env.env.environment import OSINTEnvironment
from osint_env.env.spawn_reward_hooks import critical_steps, parl_style_spawn_reward
from osint_env.llm.interface import LLMClient, RuleBasedMockLLM
from osint_env.platforms.tool_schemas import build_lookup_tools
class SwarmAgentRunner:
    """Low-width multi-agent orchestrator over a single environment episode.

    Each planner round activates a small set of role-tagged agents
    (``explorer``, ``linker``, ``reasoner``, cycled by index). The agents act
    one after another on the same environment: each issues its planned tool
    calls and then proposes one edge. Once the rounds are over a single answer
    is submitted and the episode reward is re-shaped with a spawn-aware term.
    """

    # Roles are assigned round-robin by agent index.
    _ROLES = ("explorer", "linker", "reasoner")

    # Typed entity id (e.g. ``user_12``) used to seed shared-context searches.
    _ENTITY_ID_RE = re.compile(r"\b(?:alias|user|post|thr|thread|org|loc|event)_[A-Za-z0-9_]+\b")
    # Captures the relation chain of a "relation path ..., which entity" question.
    _RELATION_PATH_RE = re.compile(r"relation path\s+(.+?),\s*which entity", flags=re.IGNORECASE)
    # User id for the generic-role profile lookup; matched case-insensitively
    # against the ORIGINAL question so the id keeps its original casing.
    _USER_ID_RE = re.compile(r"\buser_[a-zA-Z0-9_]+\b", flags=re.IGNORECASE)
    # Candidate answer ids for the last-resort guess in ``_vote_answer``.
    _ACTOR_ID_RE = re.compile(r"\b(?:alias|user)_[A-Za-z0-9_]+\b")

    def __init__(self, env: OSINTEnvironment, llm: LLMClient | None = None):
        """Bind the runner to an environment and a planner LLM.

        Args:
            env: Environment that is reset and stepped by ``run_episode``.
            llm: Client used to propose tool calls. When omitted (or falsy) a
                ``RuleBasedMockLLM`` is created.
        """
        self.env = env
        self.llm = llm or RuleBasedMockLLM()

    def run_episode(self) -> dict[str, Any]:
        """Run one episode and return the last step's info with spawn shaping.

        Returns:
            The ``info`` dict of the final environment step, updated in place:
            ``total_reward`` is replaced by the value returned from
            ``parl_style_spawn_reward``; ``reward_components`` gains the
            ``spawn_*`` entries; and ``spawn_count``,
            ``spawn_finished_subtasks``, ``spawn_critical_steps``,
            ``spawn_depth``, ``spawn_breadth`` and ``swarm_roles`` are added.
            When the environment exposes a state object, its ``total_reward``
            and ``reward_components`` are updated to match.
        """
        obs = self.env.reset()
        done = False
        info: dict[str, Any] = {}
        swarm_cfg = self.env.config.swarm
        spawn_cfg = self.env.config.spawn_reward

        spawn_count = 0
        finished_subtasks = 0
        depth_used = 0
        max_breadth_used = 0
        stage_main_steps: list[int] = []
        stage_sub_steps: list[list[int]] = []

        # Always run at least one planner round, even if configured with 0.
        for _ in range(max(1, swarm_cfg.planner_rounds)):
            if done:
                break

            # Width is bounded by the tightest of the three limits, but never
            # drops below a single agent.
            active_agents = max(1, min(swarm_cfg.max_agents, swarm_cfg.max_breadth, swarm_cfg.max_width))
            max_breadth_used = max(max_breadth_used, active_agents)
            depth_used += 1
            spawn_count += active_agents
            # Each round is booked as one step of the main (planner) agent.
            stage_main_steps.append(1)

            stage_steps: list[int] = []
            for agent_idx in range(active_agents):
                if done:
                    break

                steps_for_agent = 0
                role = self._agent_role(agent_idx)
                planned_calls = self._tool_plan(
                    obs=obs,
                    agent_idx=agent_idx,
                    role=role,
                    limit=swarm_cfg.tools_per_agent,
                )
                for call in planned_calls:
                    obs, _, done, info = self.env.step(Action(ActionType.CALL_TOOL, call))
                    steps_for_agent += 1
                    if done:
                        break

                if not done:
                    edge_payload = self._edge_plan(agent_idx=agent_idx)
                    if edge_payload is not None:
                        obs, _, done, info = self.env.step(Action(ActionType.ADD_EDGE, edge_payload))
                        steps_for_agent += 1

                # An agent counts as a finished subtask once it took any step.
                if steps_for_agent > 0:
                    finished_subtasks += 1
                stage_steps.append(steps_for_agent)

            stage_sub_steps.append(stage_steps)
            if depth_used >= swarm_cfg.max_depth:
                break

        # Submit an answer only if the episode has not already terminated.
        if not done:
            answer_guess = self._vote_answer()
            obs, _, done, info = self.env.step(Action(ActionType.ANSWER, {"answer": answer_guess}))

        crit_steps = critical_steps(
            main_steps=stage_main_steps or [1],
            parallel_subagent_steps=stage_sub_steps or [[]],
        )
        # ``or`` guards tolerate an env that reports these fields as None.
        base_total = float(info.get("total_reward") or 0.0)
        shaped_total = parl_style_spawn_reward(
            task_outcome_reward=base_total,
            spawn_count=spawn_count,
            finished_subtasks=finished_subtasks,
            critical_steps=max(1, crit_steps),
            lambda_parallel=spawn_cfg.lambda_parallel,
            lambda_finish=spawn_cfg.lambda_finish,
            anneal=spawn_cfg.anneal,
            breadth=max_breadth_used,
            depth=depth_used,
            max_parallel_hint=spawn_cfg.max_parallel_hint,
        )
        # The auxiliary term is whatever the shaping added on top of the base.
        spawn_aux = shaped_total - base_total

        components = dict(info.get("reward_components") or {})
        components["spawn_auxiliary"] = components.get("spawn_auxiliary", 0.0) + float(spawn_aux)
        components["spawn_count"] = float(spawn_count)
        components["spawn_finished_subtasks"] = float(finished_subtasks)
        components["spawn_critical_steps"] = float(crit_steps)
        components["spawn_depth"] = float(depth_used)
        components["spawn_breadth"] = float(max_breadth_used)

        info["total_reward"] = shaped_total
        info["reward_components"] = components
        info["spawn_count"] = spawn_count
        info["spawn_finished_subtasks"] = finished_subtasks
        info["spawn_critical_steps"] = crit_steps
        info["spawn_depth"] = depth_used
        info["spawn_breadth"] = max_breadth_used
        info["swarm_roles"] = [self._agent_role(i) for i in range(max_breadth_used)]

        # Keep the environment's own bookkeeping in sync with the shaped total.
        if self.env.state is not None:
            self.env.state.total_reward = shaped_total
            self.env.state.reward_components.update(components)
        return info

    @staticmethod
    def _agent_role(agent_idx: int) -> str:
        """Return the role for ``agent_idx``, cycling explorer/linker/reasoner."""
        roles = SwarmAgentRunner._ROLES
        return roles[agent_idx % len(roles)]

    def _tool_plan(self, obs: Any, agent_idx: int, role: str, limit: int) -> list[dict[str, Any]]:
        """Plan the tool calls one agent should issue for the current task.

        The LLM is asked first; its ``tool_calls`` are sanitised and capped at
        ``max(1, limit)``. If that yields nothing, a single heuristic call is
        chosen from the role and the question text.

        Args:
            obs: Observation whose ``task`` mapping may hold ``question`` and
                ``shared_context_available``.
            agent_idx: Index of the agent; only used to label the prompt.
            role: Role name of the agent.
            limit: Maximum number of LLM-proposed calls to keep.

        Returns:
            A non-empty list of ``{"tool_name": ..., "args": {...}}`` dicts.
        """
        task = obs.task
        # ``.get`` (rather than ``task["question"]``) keeps a task without a
        # question from raising before the fallback can run.
        question = str(task.get("question", ""))
        shared_context_available = bool(task.get("shared_context_available", False))

        messages = [
            {
                "role": "system",
                "content": (
                    f"question: {question}\n"
                    f"agent_role: {role}_{agent_idx}\n"
                    f"shared_context_available: {shared_context_available}\n"
                    "Return concise tool plan."
                ),
            }
        ]
        try:
            response = self.llm.generate(messages, tools=build_lookup_tools())
        except Exception:
            # Deliberate best-effort: any planner failure degrades to the
            # heuristic plan below instead of aborting the episode.
            # NOTE(review): the failure is not logged anywhere.
            response = None

        calls = self._sanitize_tool_calls(getattr(response, "tool_calls", None), limit)
        if calls:
            return calls
        return self._fallback_tool_plan(question, shared_context_available, role)

    @staticmethod
    def _sanitize_tool_calls(raw_calls: Any, limit: int) -> list[dict[str, Any]]:
        """Keep well-formed tool calls from ``raw_calls``, at most ``max(1, limit)``.

        Entries that are not dicts, have an empty/missing ``tool_name``, or
        whose ``args`` is not a dict are dropped. ``None`` yields ``[]``.
        """
        cap = max(1, limit)
        calls: list[dict[str, Any]] = []
        for call in raw_calls or []:
            if not isinstance(call, dict):
                continue
            # ``or ""`` stops a None name from becoming the string "None".
            tool_name = str(call.get("tool_name") or "").strip()
            args = call.get("args", {})
            if not tool_name or not isinstance(args, dict):
                continue
            calls.append({"tool_name": tool_name, "args": args})
            if len(calls) >= cap:
                break
        return calls

    @classmethod
    def _fallback_tool_plan(cls, question: str, shared_context_available: bool, role: str) -> list[dict[str, Any]]:
        """Choose one heuristic tool call from the role and question keywords."""
        lowered = question.lower()

        if shared_context_available and role in {"explorer", "reasoner"}:
            shared_query = cls._shared_context_query(question)
            return [{"tool_name": "search_shared_context", "args": {"query": shared_query, "k": 5}}]

        if role == "explorer":
            if "event" in lowered:
                return [{"tool_name": "search_threads", "args": {"topic": "security"}}]
            return [{"tool_name": "search_posts", "args": {"query": "Update"}}]

        if role == "linker":
            if "alias" in lowered:
                return [{"tool_name": "search_posts", "args": {"query": "alias"}}]
            return [{"tool_name": "search_people", "args": {"org": "Apex"}}]

        if role == "reasoner":
            return [{"tool_name": "search_memory", "args": {"query": question, "k": 5}}]

        # Any other role name: generic keyword-driven lookups.
        if "alias" in lowered:
            return [{"tool_name": "search_posts", "args": {"query": "Update"}}]
        # Search the original text (not ``lowered``) so the id keeps its case.
        user_match = cls._USER_ID_RE.search(question)
        if user_match:
            return [{"tool_name": "get_profile", "args": {"user_id": user_match.group(0)}}]
        return [{"tool_name": "search_people", "args": {"org": "Apex"}}]

    @staticmethod
    def _shared_context_query(question: str) -> str:
        """Derive a short search query for the shared context from a question.

        Preference order: the first typed entity id in the text, then the
        first relation of a "relation path ..., which entity" question, then
        the first alphanumeric token. A question with no token at all is
        returned unchanged.
        """
        id_match = SwarmAgentRunner._ENTITY_ID_RE.search(question)
        if id_match:
            return id_match.group(0)

        path_match = SwarmAgentRunner._RELATION_PATH_RE.search(question)
        if path_match:
            first_relation = path_match.group(1).split("->", 1)[0].strip()
            if first_relation:
                return first_relation

        tokens = re.findall(r"[A-Za-z0-9_]+", question)
        return tokens[0] if tokens else question

    def _edge_plan(self, agent_idx: int) -> dict[str, Any] | None:
        """Return the edge payload agent ``agent_idx`` should add, if any.

        The edge is taken from the task's ``supporting_edges``, indexed
        round-robin by agent. Returns ``None`` when the environment has no
        state or the task lists no supporting edges.
        """
        state = self.env.state
        if state is None or not state.task.supporting_edges:
            return None
        edges = state.task.supporting_edges
        edge = edges[agent_idx % len(edges)]
        return {
            "src": edge.src,
            "rel": edge.rel,
            "dst": edge.dst,
            "confidence": float(edge.confidence),
        }

    def _vote_answer(self) -> str:
        """Pick the answer string to submit at the end of the episode.

        If the memory graph contains at least one of the task's supporting
        edges, the task's stored answer is returned. Otherwise the first
        ``alias_*`` / ``user_*`` id mentioned in the question is used, and
        ``"unknown"`` when there is none or the environment has no state.
        """
        state = self.env.state
        if state is None:
            return "unknown"

        task = state.task
        truth = {(e.src, e.rel, e.dst) for e in task.supporting_edges}
        pred = {(e.src, e.rel, e.dst) for e in self.env.memory_graph.edges}
        if truth & pred:
            return task.answer

        # A word-boundary match drops punctuation glued to the id ("user_1,")
        # and still finds ids that follow punctuation ("(user_1)").
        id_match = self._ACTOR_ID_RE.search(task.question)
        return id_match.group(0) if id_match else "unknown"