ritishshrirao committed
Commit d6fbf54 · Parent: ce675d4

Update LLM interface, add multi-agent data generation
.gitignore CHANGED
@@ -1,3 +1,4 @@
  *.pyc
  blueprint.txt
- *.egg-info
+ *.egg-info
+ artifacts/*
README.md CHANGED
@@ -14,6 +14,10 @@ The environment models a realistic workflow for information discovery and linkin
  4. Let agents call tools, add graph edges, and submit answers.
  5. Score episodes using a composite reward that combines correctness, retrieval utility, graph quality, and efficiency.

+ The tool layer also supports semantic-memory retrieval over prior observations:
+
+ - search_memory(query, k): vector-style retrieval over accumulated tool outputs.
+
  ## 2. Current Capabilities

  - Single-agent baseline runner.
@@ -42,6 +46,46 @@ Example:

  The project requires Python 3.10+.

+ ## 3.1 LLM Backends
+
+ The environment supports three LLM providers:
+
+ - mock: deterministic fallback for reproducible local tests.
+ - ollama: local model inference (recommended for offline development).
+ - openai: remote API provider using an API key.
+
+ The provider is configured through config/shared_config.json (llm block) and can be overridden from the CLI.
+
+ ### Local Ollama Setup (Qwen 3 2B)
+
+ 1. Install Ollama.
+ 2. Start the Ollama service.
+ 3. Pull the model:
+
+     ollama pull qwen3:2b
+
+     If your local Ollama registry does not expose `qwen3:2b`, use:
+
+     ollama pull qwen3:1.7b
+     ollama cp qwen3:1.7b qwen3:2b
+
+ 4. Run the demo in swarm mode with the local model:
+
+     osint-env demo --agent-mode swarm --llm-provider ollama --llm-model qwen3:2b
+
+ ### OpenAI Setup
+
+ 1. Export the API key:
+
+     export OPENAI_API_KEY="your_key_here"
+
+ 2. Run with the OpenAI backend:
+
+     osint-env eval --episodes 10 --llm-provider openai --llm-model gpt-4o-mini
+
+ You can also provide the key via config/shared_config.json using llm.openai_api_key,
+ or specify a custom environment variable name via llm.openai_api_key_env.
+
  ## 4. Repository Layout

  src/osint_env/
@@ -71,6 +115,7 @@ This file includes:
  - swarm limits,
  - spawn reward shaping hyperparameters,
  - seeding defaults,
+ - llm backend defaults,
  - runtime output paths.

  Default swarm settings are intentionally conservative:
@@ -105,6 +150,10 @@ All commands accept:
  - --config for shared config path (default: config/shared_config.json)
  - --seed-file for seeded graph/task input JSON
  - --agent-mode with values: config, single, swarm
+ - --llm-provider with values: config, mock, ollama, openai
+ - --llm-model to override the configured model
+ - --ollama-base-url to override the local Ollama endpoint
+ - --openai-api-key or --openai-api-key-env for OpenAI authentication

  Main commands:

@@ -132,6 +181,14 @@ Main commands:

  osint-env viz --with-demo --output artifacts/osint_explorer.html

+ 7. Benchmark with the local Qwen model:
+
+     osint-env benchmark --episodes 20 --agent-mode swarm --llm-provider ollama --llm-model qwen3:2b --name qwen3_swarm
+
+ 8. Fast local smoke benchmark:
+
+     osint-env benchmark --episodes 1 --agent-mode swarm --llm-provider ollama --llm-model qwen3:2b --seed-file config/seed_ollama_smoke.json --name ollama_qwen_smoke
+
  ## 8. Multi-Agent Swarm Design

  Swarm orchestration is implemented in src/osint_env/agents/swarm_agent.py.
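
Note: the README additions stop at the command level; for reference, the call and reply shape of the new search_memory tool, as exercised by tests/test_environment.py later in this commit (a sketch; the results string is illustrative, the {"results", "count"} reply keys come from src/osint_env/env/environment.py):

    # One search_memory action payload and the reply shape the environment builds.
    call = {"tool_name": "search_memory", "args": {"query": "Update", "k": 3}}
    reply = {"results": ["search_posts {'query': 'Update'} ..."], "count": 1}
    assert sorted(reply) == ["count", "results"]
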
config/seed_example.json CHANGED
@@ -18,6 +18,7 @@
          }
        }
      ],
+     "_note": "Use with --seed-file. LLM provider and API keys are configured in config/shared_config.json or CLI flags.",
      "seeded_edges": [
        {
          "src": "alias_seed_001",
config/seed_ollama_smoke.json ADDED
@@ -0,0 +1,51 @@
+ {
+   "seeding": {
+     "seeded_nodes": [
+       {
+         "node_id": "alias_smoke_001",
+         "node_type": "alias",
+         "attrs": {
+           "handle": "@smoke_alias"
+         }
+       },
+       {
+         "node_id": "user_smoke_001",
+         "node_type": "user",
+         "attrs": {
+           "name": "Smoke User",
+           "org": "Apex Dynamics",
+           "location": "Bengaluru"
+         }
+       }
+     ],
+     "seeded_edges": [
+       {
+         "src": "alias_smoke_001",
+         "rel": "alias_of",
+         "dst": "user_smoke_001",
+         "confidence": 1.0
+       }
+     ],
+     "seeded_questions": [
+       {
+         "task_type": "identity_resolution",
+         "question": "Which canonical user owns alias alias_smoke_001?",
+         "answer": "user_smoke_001",
+         "supporting_edges": [
+           {
+             "src": "alias_smoke_001",
+             "rel": "alias_of",
+             "dst": "user_smoke_001"
+           }
+         ],
+         "metadata": {
+           "source": "ollama_smoke"
+         }
+       }
+     ],
+     "llm_generate_remaining_graph": false,
+     "llm_generate_remaining_tasks": false,
+     "llm_generated_edge_budget": 0,
+     "llm_generated_task_budget": 0
+   }
+ }
config/shared_config.json CHANGED
@@ -29,7 +29,22 @@
      "llm_generate_remaining_graph": true,
      "llm_generate_remaining_tasks": true,
      "llm_generated_edge_budget": 6,
-     "llm_generated_task_budget": 8
+     "llm_generated_task_budget": 8,
+     "llm_generation_parallel": true,
+     "llm_generation_workers": 3,
+     "llm_generation_retries": 2,
+     "allow_template_fallback_on_llm_failure": false
+   },
+   "llm": {
+     "provider": "ollama",
+     "model": "qwen3:2b",
+     "temperature": 0.1,
+     "max_tokens": 256,
+     "timeout_seconds": 240,
+     "ollama_base_url": "http://127.0.0.1:11434",
+     "openai_base_url": "https://api.openai.com/v1",
+     "openai_api_key_env": "OPENAI_API_KEY",
+     "openai_api_key": ""
    },
    "runtime": {
      "default_episodes": 20,
pyproject.toml CHANGED
@@ -4,7 +4,11 @@ version = "0.1.0"
  description = "OSINT-style multi-platform information ecosystem environment for LLM agents."
  readme = "README.md"
  requires-python = ">=3.10"
- dependencies = ["openenv"]
+ dependencies = [
+     "openenv",
+     "openai>=1.40.0",
+     "requests>=2.31.0",
+ ]

  [project.scripts]
  osint-env = "osint_env.cli:main"
src/osint_env/agents/single_agent.py CHANGED
@@ -17,8 +17,13 @@ class SingleAgentRunner:
          while not done:
              messages = [{"role": "system", "content": f"question: {obs.task['question']}"}]
              tools = []
-             llm_resp = self.llm.generate(messages, tools)
-             for call in llm_resp.tool_calls[:2]:
+             try:
+                 llm_resp = self.llm.generate(messages, tools)
+                 planned_calls = llm_resp.tool_calls[:2]
+             except Exception:
+                 planned_calls = []
+
+             for call in planned_calls:
                  obs, _, done, info = self.env.step(Action(ActionType.CALL_TOOL, call))
                  if done:
                      break
src/osint_env/agents/swarm_agent.py CHANGED
@@ -48,7 +48,13 @@ class SwarmAgentRunner:
                  break

              steps_for_agent = 0
-             planned_calls = self._tool_plan(obs=obs, agent_idx=agent_idx, limit=swarm_cfg.tools_per_agent)
+             role = self._agent_role(agent_idx)
+             planned_calls = self._tool_plan(
+                 obs=obs,
+                 agent_idx=agent_idx,
+                 role=role,
+                 limit=swarm_cfg.tools_per_agent,
+             )
              for call in planned_calls:
                  obs, _, done, info = self.env.step(Action(ActionType.CALL_TOOL, call))
                  steps_for_agent += 1
@@ -109,6 +115,7 @@
          info["spawn_critical_steps"] = crit_steps
          info["spawn_depth"] = depth_used
          info["spawn_breadth"] = max_breadth_used
+         info["swarm_roles"] = [self._agent_role(i) for i in range(max_breadth_used)]

          if self.env.state is not None:
              self.env.state.total_reward = shaped_total
@@ -116,21 +123,29 @@

          return info

-     def _tool_plan(self, obs: Any, agent_idx: int, limit: int) -> list[dict[str, Any]]:
+     @staticmethod
+     def _agent_role(agent_idx: int) -> str:
+         roles = ["explorer", "linker", "reasoner"]
+         return roles[agent_idx % len(roles)]
+
+     def _tool_plan(self, obs: Any, agent_idx: int, role: str, limit: int) -> list[dict[str, Any]]:
          messages = [
              {
                  "role": "system",
                  "content": (
                      f"question: {obs.task['question']}\n"
-                     f"agent_role: swarm_worker_{agent_idx}\n"
+                     f"agent_role: {role}_{agent_idx}\n"
                      "Return concise tool plan."
                  ),
              }
          ]
-         response = self.llm.generate(messages, tools=[])
+         try:
+             response = self.llm.generate(messages, tools=[])
+         except Exception:
+             response = None

          calls: list[dict[str, Any]] = []
-         for call in response.tool_calls:
+         for call in (response.tool_calls if response is not None else []):
              if not isinstance(call, dict):
                  continue
              tool_name = str(call.get("tool_name", "")).strip()
@@ -145,6 +160,19 @@
              return calls

          question = str(obs.task.get("question", "")).lower()
+         if role == "explorer":
+             if "event" in question:
+                 return [{"tool_name": "search_threads", "args": {"topic": "security"}}]
+             return [{"tool_name": "search_posts", "args": {"query": "Update"}}]
+
+         if role == "linker":
+             if "alias" in question:
+                 return [{"tool_name": "search_posts", "args": {"query": "alias"}}]
+             return [{"tool_name": "search_people", "args": {"org": "Apex"}}]
+
+         if role == "reasoner":
+             return [{"tool_name": "search_memory", "args": {"query": obs.task.get("question", ""), "k": 5}}]
+
          if "alias" in question:
              return [{"tool_name": "search_posts", "args": {"query": "Update"}}]

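
Note: the role rotation added above is a plain modulo over three fixed roles. A standalone sketch of the same behavior (names copied from the diff; the assertion is illustrative):

    # Workers cycle explorer -> linker -> reasoner by agent index; index 3 wraps around.
    roles = ["explorer", "linker", "reasoner"]

    def agent_role(agent_idx: int) -> str:
        return roles[agent_idx % len(roles)]

    assert [agent_role(i) for i in range(5)] == ["explorer", "linker", "reasoner", "explorer", "linker"]
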
src/osint_env/cli.py CHANGED
@@ -11,6 +11,7 @@ from osint_env.env.environment import OSINTEnvironment
  from osint_env.env.reward import compute_graph_f1
  from osint_env.eval.leaderboard import append_leaderboard_record, load_leaderboard, render_leaderboard_table
  from osint_env.eval.runner import run_evaluation
+ from osint_env.llm import build_llm_client
  from osint_env.viz import export_dashboard


@@ -24,6 +25,24 @@ def _add_common_args(parser: argparse.ArgumentParser) -> None:
          choices=["config", "single", "swarm"],
          help="Use shared config mode or override runner mode explicitly.",
      )
+     parser.add_argument(
+         "--llm-provider",
+         type=str,
+         default="config",
+         choices=["config", "mock", "ollama", "openai"],
+         help="Use shared config provider or override explicitly.",
+     )
+     parser.add_argument("--llm-model", type=str, default="", help="Override model name for selected LLM provider.")
+     parser.add_argument("--llm-timeout-seconds", type=int, default=0, help="Override LLM request timeout in seconds.")
+     parser.add_argument("--ollama-base-url", type=str, default="", help="Override Ollama base URL.")
+     parser.add_argument("--openai-base-url", type=str, default="", help="Override OpenAI base URL.")
+     parser.add_argument("--openai-api-key", type=str, default="", help="OpenAI API key override.")
+     parser.add_argument(
+         "--openai-api-key-env",
+         type=str,
+         default="",
+         help="Environment variable name for OpenAI API key.",
+     )


  def build_parser() -> argparse.ArgumentParser:
@@ -88,6 +107,21 @@ def _resolve_environment_config(args: argparse.Namespace) -> tuple[EnvironmentCo
      if args.seed_file:
          env_cfg.seeding = load_seeding_config(args.seed_file)

+     if args.llm_provider != "config":
+         env_cfg.llm.provider = args.llm_provider
+     if args.llm_model:
+         env_cfg.llm.model = args.llm_model
+     if int(args.llm_timeout_seconds) > 0:
+         env_cfg.llm.timeout_seconds = int(args.llm_timeout_seconds)
+     if args.ollama_base_url:
+         env_cfg.llm.ollama_base_url = args.ollama_base_url
+     if args.openai_base_url:
+         env_cfg.llm.openai_base_url = args.openai_base_url
+     if args.openai_api_key:
+         env_cfg.llm.openai_api_key = args.openai_api_key
+     if args.openai_api_key_env:
+         env_cfg.llm.openai_api_key_env = args.openai_api_key_env
+
      if args.agent_mode == "single":
          env_cfg.swarm.enabled = False
      elif args.agent_mode == "swarm":
@@ -104,8 +138,8 @@ def _resolve_environment_config(args: argparse.Namespace) -> tuple[EnvironmentCo

  def _runner_for(env: OSINTEnvironment) -> SingleAgentRunner | SwarmAgentRunner:
      if env.config.swarm.enabled:
-         return SwarmAgentRunner(env)
-     return SingleAgentRunner(env)
+         return SwarmAgentRunner(env, llm=build_llm_client(env.config.llm))
+     return SingleAgentRunner(env, llm=build_llm_client(env.config.llm))


  def main() -> None:
@@ -130,8 +164,8 @@ def main() -> None:
      for seed in seed_values:
          seeded_cfg = clone_environment_config(env_cfg)
          seeded_cfg.seed = seed
-         env = OSINTEnvironment(seeded_cfg)
-         evaluation = run_evaluation(env, episodes=episodes, return_details=True)
+         env = OSINTEnvironment(seeded_cfg, llm=build_llm_client(seeded_cfg.llm))
+         evaluation = run_evaluation(env, episodes=episodes, return_details=True, llm=build_llm_client(seeded_cfg.llm))
          summary = evaluation["summary"]
          run_name = f"{args.name_prefix}_seed{seed}"
          record = append_leaderboard_record(
@@ -171,15 +205,16 @@ def main() -> None:
          )
          return

-     env = OSINTEnvironment(env_cfg)
+     llm_client = build_llm_client(env_cfg.llm)
+     env = OSINTEnvironment(env_cfg, llm=llm_client)
      if args.cmd == "demo":
          info = _runner_for(env).run_episode()
          print(json.dumps(info, indent=2, sort_keys=True))
      elif args.cmd == "eval":
-         metrics = run_evaluation(env, episodes=episodes)
+         metrics = run_evaluation(env, episodes=episodes, llm=llm_client)
          print(json.dumps(metrics, indent=2, sort_keys=True))
      elif args.cmd == "benchmark":
-         evaluation = run_evaluation(env, episodes=episodes, return_details=True)
+         evaluation = run_evaluation(env, episodes=episodes, return_details=True, llm=llm_client)
          summary = evaluation["summary"]
          record = append_leaderboard_record(
              path=leaderboard_path,
@@ -212,7 +247,7 @@ def main() -> None:
          print(json.dumps(payload, indent=2, sort_keys=True))
      elif args.cmd == "viz":
          if args.with_demo:
-             SingleAgentRunner(env).run_episode()
+             _runner_for(env).run_episode()

          graph_f1 = 0.0
          if env.state is not None:
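
Note: the new flags layer on top of config/shared_config.json; any flag left at its default keeps the configured value. Example invocations built only from flags defined above (MY_OPENAI_KEY is a placeholder variable name):

    osint-env demo                        # provider/model come from shared_config.json
    osint-env demo --llm-provider mock    # force the deterministic mock backend
    osint-env eval --episodes 5 --llm-provider ollama --llm-model qwen3:2b --ollama-base-url http://127.0.0.1:11434
    osint-env eval --episodes 5 --llm-provider openai --openai-api-key-env MY_OPENAI_KEY
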
src/osint_env/config/shared.py CHANGED
@@ -8,6 +8,7 @@ from typing import Any

  from osint_env.domain.models import (
      EnvironmentConfig,
+     LLMConfig,
      NodeType,
      SeedingConfig,
      SeedEdgeSpec,
@@ -154,6 +155,13 @@ def _parse_seeding(data: dict[str, Any]) -> SeedingConfig:
          llm_generate_remaining_tasks=_parse_bool(data.get("llm_generate_remaining_tasks"), True),
          llm_generated_edge_budget=max(0, _parse_int(data.get("llm_generated_edge_budget"), 6)),
          llm_generated_task_budget=max(0, _parse_int(data.get("llm_generated_task_budget"), 8)),
+         llm_generation_parallel=_parse_bool(data.get("llm_generation_parallel"), True),
+         llm_generation_workers=max(1, _parse_int(data.get("llm_generation_workers"), 3)),
+         llm_generation_retries=max(1, _parse_int(data.get("llm_generation_retries"), 2)),
+         allow_template_fallback_on_llm_failure=_parse_bool(
+             data.get("allow_template_fallback_on_llm_failure"),
+             False,
+         ),
      )


@@ -170,6 +178,7 @@ def _parse_environment(payload: dict[str, Any]) -> EnvironmentConfig:
      swarm_data = _as_dict(payload.get("swarm", env_data.get("swarm", {})))
      spawn_data = _as_dict(payload.get("spawn_reward", env_data.get("spawn_reward", {})))
      seeding_data = _as_dict(payload.get("seeding", env_data.get("seeding", {})))
+     llm_data = _as_dict(payload.get("llm", env_data.get("llm", {})))

      env = EnvironmentConfig(
          n_users=max(4, _parse_int(env_data.get("n_users"), 40)),
@@ -198,6 +207,19 @@ def _parse_environment(payload: dict[str, Any]) -> EnvironmentConfig:
      )

      env.seeding = _parse_seeding(seeding_data)
+     env.llm = LLMConfig(
+         provider=str(llm_data.get("provider", "mock")).strip() or "mock",
+         model=str(llm_data.get("model", "qwen3:2b")).strip() or "qwen3:2b",
+         temperature=_parse_float(llm_data.get("temperature"), 0.1),
+         max_tokens=max(1, _parse_int(llm_data.get("max_tokens"), 256)),
+         timeout_seconds=max(1, _parse_int(llm_data.get("timeout_seconds"), 240)),
+         ollama_base_url=str(llm_data.get("ollama_base_url", "http://127.0.0.1:11434")).strip()
+         or "http://127.0.0.1:11434",
+         openai_base_url=str(llm_data.get("openai_base_url", "https://api.openai.com/v1")).strip()
+         or "https://api.openai.com/v1",
+         openai_api_key_env=str(llm_data.get("openai_api_key_env", "OPENAI_API_KEY")).strip() or "OPENAI_API_KEY",
+         openai_api_key=str(llm_data.get("openai_api_key", "")).strip(),
+     )
      return env


src/osint_env/data/generator.py CHANGED
@@ -1,5 +1,6 @@
  from __future__ import annotations

+ from concurrent.futures import ThreadPoolExecutor, as_completed
  import json
  import random
  import re
@@ -133,6 +134,72 @@ class DatasetGenerator:
              items.append(SeedEdgeSpec(src=src, rel=rel, dst=dst, confidence=confidence))
          return items

+     @staticmethod
+     def _split_budget(total: int, parts: int) -> list[int]:
+         if total <= 0:
+             return []
+         slots = max(1, parts)
+         base = total // slots
+         remainder = total % slots
+         chunks = [base + (1 if i < remainder else 0) for i in range(slots)]
+         return [chunk for chunk in chunks if chunk > 0]
+
+     @staticmethod
+     def _shared_context_blob(graph: CanonicalGraph, node_limit: int = 100, edge_limit: int = 80) -> str:
+         payload = {
+             "known_nodes": sorted(graph.nodes.keys())[:node_limit],
+             "known_edges": [
+                 {"src": edge.src, "rel": edge.rel, "dst": edge.dst}
+                 for edge in graph.edges[: min(edge_limit, len(graph.edges))]
+             ],
+         }
+         return json.dumps(payload)
+
+     def _llm_generate_json_with_retry(self, prompt: str) -> Any:
+         if self.llm is None:
+             return None
+
+         attempts = max(1, int(self.config.seeding.llm_generation_retries))
+         for _ in range(attempts):
+             try:
+                 response = self.llm.generate([{"role": "system", "content": prompt}], tools=[])
+             except Exception:
+                 continue
+             parsed = self._extract_json_blob(response.content)
+             if parsed is not None:
+                 return parsed
+         return None
+
+     def _run_generation_workers(self, prompts: list[str]) -> list[Any]:
+         if not prompts:
+             return []
+
+         max_workers = max(1, min(self.config.seeding.llm_generation_workers, len(prompts)))
+         if not self.config.seeding.llm_generation_parallel or max_workers == 1:
+             output: list[Any] = []
+             for prompt in prompts:
+                 parsed = self._llm_generate_json_with_retry(prompt)
+                 if parsed is not None:
+                     output.append(parsed)
+             return output
+
+         output = []
+         with ThreadPoolExecutor(max_workers=max_workers) as executor:
+             futures = [executor.submit(self._llm_generate_json_with_retry, prompt) for prompt in prompts]
+             for future in as_completed(futures):
+                 try:
+                     parsed = future.result()
+                 except Exception:
+                     parsed = None
+                 if parsed is not None:
+                     output.append(parsed)
+         return output
+
+     def _template_fallback_allowed(self) -> bool:
+         if self.llm is None:
+             return True
+         return bool(self.config.seeding.allow_template_fallback_on_llm_failure)
+
      def _template_generated_edges(self, graph: CanonicalGraph, budget: int) -> list[Edge]:
          if budget <= 0:
              return []
@@ -168,30 +235,79 @@
          if self.llm is None:
              return self._template_generated_edges(graph, budget)

-         sample_edges = [
-             {"src": edge.src, "rel": edge.rel, "dst": edge.dst}
-             for edge in graph.edges[: min(40, len(graph.edges))]
-         ]
-         sample_nodes = sorted(graph.nodes.keys())[:80]
-         prompt = (
-             "SEED_GRAPH_EXPANSION\n"
-             "Generate additional plausible graph edges to improve retrieval for OSINT tasks.\n"
-             "Return STRICT JSON object: {\"edges\": [{\"src\": str, \"rel\": str, \"dst\": str, \"confidence\": float}]}.\n"
-             "Use only known node ids when possible. Avoid duplicates.\n"
-             f"Budget: {budget}\n"
-             f"Known nodes: {json.dumps(sample_nodes)}\n"
-             f"Known edges sample: {json.dumps(sample_edges)}"
-         )
-         response = self.llm.generate([{"role": "system", "content": prompt}], tools=[])
-         parsed = self._extract_json_blob(response.content)
-         if isinstance(parsed, dict):
-             edges = self._normalize_edge_candidates(parsed.get("edges"))
-             if edges:
-                 return [
-                     Edge(src=e.src, rel=e.rel, dst=e.dst, confidence=float(e.confidence))
-                     for e in edges[:budget]
-                 ]
-         return self._template_generated_edges(graph, budget)
+         shared_context = self._shared_context_blob(graph)
+         workers = max(1, min(self.config.seeding.llm_generation_workers, budget))
+         chunks = self._split_budget(budget, workers)
+         focus_tracks = ["entity_linking", "network_expansion", "org_location", "event_trace"]
+
+         prompts: list[str] = []
+         for idx, chunk_budget in enumerate(chunks):
+             focus = focus_tracks[idx % len(focus_tracks)]
+             prompts.append(
+                 (
+                     "SEED_GRAPH_EXPANSION_AGENT\n"
+                     "SHARED_CONTEXT\n"
+                     f"{shared_context}\n"
+                     f"worker_id: {idx}\n"
+                     f"focus: {focus}\n"
+                     f"budget: {chunk_budget}\n"
+                     "Generate plausible graph edges for OSINT retrieval.\n"
+                     "Return STRICT JSON object: {\"edges\": [{\"src\": str, \"rel\": str, \"dst\": str, \"confidence\": float}]}.\n"
+                     "Prefer known nodes from SHARED_CONTEXT and avoid duplicates."
+                 )
+             )
+
+         generated: list[Edge] = []
+         seen: set[tuple[str, str, str]] = set()
+         for payload in self._run_generation_workers(prompts):
+             raw_edges: Any = None
+             if isinstance(payload, dict):
+                 raw_edges = payload.get("edges")
+             elif isinstance(payload, list):
+                 raw_edges = payload
+             for edge_spec in self._normalize_edge_candidates(raw_edges):
+                 key = (edge_spec.src, edge_spec.rel, edge_spec.dst)
+                 if key in seen:
+                     continue
+                 seen.add(key)
+                 generated.append(Edge(edge_spec.src, edge_spec.rel, edge_spec.dst, float(edge_spec.confidence)))
+                 if len(generated) >= budget:
+                     break
+             if len(generated) >= budget:
+                 break
+
+         if len(generated) < budget:
+             residual = budget - len(generated)
+             residual_prompt = (
+                 "SEED_GRAPH_EXPANSION_AGENT\n"
+                 "SHARED_CONTEXT\n"
+                 f"{shared_context}\n"
+                 f"budget: {residual}\n"
+                 "Generate any remaining high-utility edges.\n"
+                 "Return STRICT JSON object: {\"edges\": [{\"src\": str, \"rel\": str, \"dst\": str, \"confidence\": float}]}."
+             )
+             payload = self._llm_generate_json_with_retry(residual_prompt)
+             raw_edges = payload.get("edges") if isinstance(payload, dict) else payload
+             for edge_spec in self._normalize_edge_candidates(raw_edges):
+                 key = (edge_spec.src, edge_spec.rel, edge_spec.dst)
+                 if key in seen:
+                     continue
+                 seen.add(key)
+                 generated.append(Edge(edge_spec.src, edge_spec.rel, edge_spec.dst, float(edge_spec.confidence)))
+                 if len(generated) >= budget:
+                     break
+
+         if len(generated) < budget and self._template_fallback_allowed():
+             for edge in self._template_generated_edges(graph, budget - len(generated)):
+                 key = (edge.src, edge.rel, edge.dst)
+                 if key in seen:
+                     continue
+                 seen.add(key)
+                 generated.append(edge)
+                 if len(generated) >= budget:
+                     break
+
+         return generated[:budget]

      @staticmethod
      def _extract_entity_tokens(question: str) -> list[str]:
@@ -311,24 +427,53 @@
              for edge in graph.edges
              if edge.rel in {"alias_of", "connected_to", "works_at"}
          ][:60]
-         prompt = (
-             "SEED_TASK_EXPANSION\n"
-             "Generate additional OSINT QA tasks from this graph sample.\n"
-             "Return STRICT JSON object: {\"tasks\": [{\"task_type\": str, \"question\": str, \"answer\": str, \"supporting_edges\": [{\"src\": str, \"rel\": str, \"dst\": str}]}]}.\n"
-             f"Task budget: {count}\n"
-             f"Edge sample: {json.dumps(candidate_edges)}"
+         shared_context = json.dumps(
+             {
+                 "known_nodes": sorted(graph.nodes.keys())[:100],
+                 "edge_sample": candidate_edges,
+             }
          )
-         response = self.llm.generate([{"role": "system", "content": prompt}], tools=[])
-         parsed = self._extract_json_blob(response.content)
+         workers = max(1, min(self.config.seeding.llm_generation_workers, count))
+         chunks = self._split_budget(count, workers)
+         focus_tracks = ["identity_resolution", "network_discovery", "event_tracing", "deanonymization"]
+
+         prompts: list[str] = []
+         for idx, chunk_budget in enumerate(chunks):
+             focus = focus_tracks[idx % len(focus_tracks)]
+             prompts.append(
+                 (
+                     "SEED_TASK_EXPANSION_AGENT\n"
+                     "SHARED_CONTEXT\n"
+                     f"{shared_context}\n"
+                     f"worker_id: {idx}\n"
+                     f"focus: {focus}\n"
+                     f"task_budget: {chunk_budget}\n"
+                     "Generate OSINT QA tasks with answers and support edges.\n"
+                     "Return STRICT JSON object: {\"tasks\": [{\"task_type\": str, \"question\": str, \"answer\": str, \"supporting_edges\": [{\"src\": str, \"rel\": str, \"dst\": str, \"confidence\": float}]}]}."
+                 )
+             )

          llm_tasks: list[TaskInstance] = []
-         if isinstance(parsed, dict) and isinstance(parsed.get("tasks"), list):
-             for i, row in enumerate(parsed["tasks"]):
+         seen_questions: set[str] = set()
+         for payload in self._run_generation_workers(prompts):
+             raw_tasks: Any = None
+             if isinstance(payload, dict):
+                 raw_tasks = payload.get("tasks")
+             elif isinstance(payload, list):
+                 raw_tasks = payload
+             if not isinstance(raw_tasks, list):
+                 continue
+
+             for row in raw_tasks:
                  if not isinstance(row, dict):
                      continue
                  question = str(row.get("question", "")).strip()
                  if not question:
                      continue
+                 key = question.lower()
+                 if key in seen_questions:
+                     continue
+                 seen_questions.add(key)
                  answer = str(row.get("answer", "")).strip() or self._infer_answer_from_question(question, graph)
                  task_type = str(row.get("task_type", "llm_generated")).strip() or "llm_generated"
                  support_specs = self._normalize_edge_candidates(row.get("supporting_edges"))
@@ -338,18 +483,63 @@
                      support = self._infer_support_edges(question, answer, graph)
                  llm_tasks.append(
                      TaskInstance(
-                         task_id=f"task_{start_idx + i}",
+                         task_id=f"task_{start_idx + len(llm_tasks)}",
                          task_type=task_type,
                          question=question,
                          answer=answer,
                          supporting_edges=support,
-                         metadata={"generated_by": "llm"},
+                         metadata={"generated_by": "llm", "shared_context": True},
                      )
                  )
                  if len(llm_tasks) >= count:
                      break
+             if len(llm_tasks) >= count:
+                 break

          if len(llm_tasks) < count:
+             residual = count - len(llm_tasks)
+             residual_prompt = (
+                 "SEED_TASK_EXPANSION_AGENT\n"
+                 "SHARED_CONTEXT\n"
+                 f"{shared_context}\n"
+                 f"task_budget: {residual}\n"
+                 "Generate additional tasks not already present in SHARED_CONTEXT.\n"
+                 "Return STRICT JSON object: {\"tasks\": [{\"task_type\": str, \"question\": str, \"answer\": str, \"supporting_edges\": [{\"src\": str, \"rel\": str, \"dst\": str, \"confidence\": float}]}]}."
+             )
+             payload = self._llm_generate_json_with_retry(residual_prompt)
+             raw_tasks = payload.get("tasks") if isinstance(payload, dict) else payload
+             if isinstance(raw_tasks, list):
+                 for row in raw_tasks:
+                     if not isinstance(row, dict):
+                         continue
+                     question = str(row.get("question", "")).strip()
+                     if not question:
+                         continue
+                     key = question.lower()
+                     if key in seen_questions:
+                         continue
+                     seen_questions.add(key)
+                     answer = str(row.get("answer", "")).strip() or self._infer_answer_from_question(question, graph)
+                     task_type = str(row.get("task_type", "llm_generated")).strip() or "llm_generated"
+                     support_specs = self._normalize_edge_candidates(row.get("supporting_edges"))
+                     if support_specs:
+                         support = [Edge(e.src, e.rel, e.dst, e.confidence) for e in support_specs]
+                     else:
+                         support = self._infer_support_edges(question, answer, graph)
+                     llm_tasks.append(
+                         TaskInstance(
+                             task_id=f"task_{start_idx + len(llm_tasks)}",
+                             task_type=task_type,
+                             question=question,
+                             answer=answer,
+                             supporting_edges=support,
+                             metadata={"generated_by": "llm", "shared_context": True},
+                         )
+                     )
+                     if len(llm_tasks) >= count:
+                         break
+
+         if len(llm_tasks) < count and self._template_fallback_allowed():
              llm_tasks.extend(
                  self._template_tasks(
                      graph,
@@ -464,7 +654,7 @@

      def generate_tasks(self, graph: CanonicalGraph, views: PlatformViews, count: int = 12) -> list[TaskInstance]:
          tasks = self._seeded_tasks(graph)
-         target_count = max(count, len(tasks))
+         target_count = max(1, count, len(tasks))

          llm_budget = min(
              max(0, self.config.seeding.llm_generated_task_budget),
@@ -473,7 +663,10 @@
          if self.config.seeding.llm_generate_remaining_tasks and llm_budget > 0:
              tasks.extend(self._llm_generated_tasks(graph, count=llm_budget, start_idx=len(tasks)))

-         if len(tasks) < target_count:
+         if len(tasks) < target_count and self._template_fallback_allowed():
              tasks.extend(self._template_tasks(graph, count=target_count - len(tasks), start_idx=len(tasks)))

+         if not tasks:
+             tasks.extend(self._template_tasks(graph, count=target_count, start_idx=0))
+
          return tasks[:target_count]
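
Note: _split_budget shards a generation budget across workers so chunk sizes differ by at most one and empty chunks are dropped. A standalone sketch of the same arithmetic (function renamed to avoid implying it is the repo's import path):

    # First (total % parts) workers receive one extra item; zero-size chunks vanish.
    def split_budget(total: int, parts: int) -> list[int]:
        if total <= 0:
            return []
        slots = max(1, parts)
        base, remainder = divmod(total, slots)
        return [c for c in (base + (1 if i < remainder else 0) for i in range(slots)) if c > 0]

    assert split_budget(8, 3) == [3, 3, 2]  # budgets stay within one item of each other
    assert split_budget(2, 4) == [1, 1]     # more workers than budget: surplus slots dropped
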
src/osint_env/domain/models.py CHANGED
@@ -105,6 +105,10 @@ class SeedingConfig:
      llm_generate_remaining_tasks: bool = True
      llm_generated_edge_budget: int = 6
      llm_generated_task_budget: int = 8
+     llm_generation_parallel: bool = True
+     llm_generation_workers: int = 3
+     llm_generation_retries: int = 2
+     allow_template_fallback_on_llm_failure: bool = False


  @dataclass(slots=True)
@@ -126,6 +130,19 @@ class SpawnRewardConfig:
      max_parallel_hint: int = 3


+ @dataclass(slots=True)
+ class LLMConfig:
+     provider: str = "mock"
+     model: str = "qwen3:2b"
+     temperature: float = 0.1
+     max_tokens: int = 256
+     timeout_seconds: int = 240
+     ollama_base_url: str = "http://127.0.0.1:11434"
+     openai_base_url: str = "https://api.openai.com/v1"
+     openai_api_key_env: str = "OPENAI_API_KEY"
+     openai_api_key: str = ""
+
+
  @dataclass(slots=True)
  class EnvironmentConfig:
      n_users: int = 40
@@ -137,3 +154,4 @@ class EnvironmentConfig:
      seeding: SeedingConfig = field(default_factory=SeedingConfig)
      swarm: SwarmConfig = field(default_factory=SwarmConfig)
      spawn_reward: SpawnRewardConfig = field(default_factory=SpawnRewardConfig)
+     llm: LLMConfig = field(default_factory=LLMConfig)
src/osint_env/env/environment.py CHANGED
@@ -103,15 +103,27 @@ class OSINTEnvironment(Env):
              penalty = 0.05
          self.state.call_fingerprints.add(fp)

-         output = self.tools.call(tool_name, args)
+         invalid_tool_penalty = 0.0
+         try:
+             if tool_name == "search_memory":
+                 query = str(args.get("query", "")).strip()
+                 top_k = int(args.get("k", 5)) if str(args.get("k", "")).strip() else 5
+                 results = self.semantic_memory.search(query=query, k=max(1, top_k)) if query else []
+                 output = {"results": results, "count": len(results)}
+             else:
+                 output = self.tools.call(tool_name, args)
+         except Exception as exc:
+             output = {"error": str(exc)}
+             invalid_tool_penalty = -0.25
          self.state.tool_outputs.append({"tool": tool_name, "args": args, "output": output})
          self.semantic_memory.add(f"{tool_name} {args} {output}", {"tool": tool_name})
          relevance_bonus = 0.08 * self._tool_relevance(self.state.task, output)
-         total = penalty + relevance_bonus
+         total = penalty + relevance_bonus + invalid_tool_penalty
          self._accumulate_reward_components(
              {
                  "tool_novelty": penalty,
                  "tool_relevance": relevance_bonus,
+                 "invalid_tool_penalty": invalid_tool_penalty,
              }
          )
          return total
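
Note: per the diff, a tool step's reward is now the sum of three components. A worked example, assuming the 0.05 novelty value from the visible branch and a _tool_relevance score of 0.5 (both assumptions for illustration):

    penalty = 0.05                  # "tool_novelty" term for a fresh call fingerprint
    relevance_bonus = 0.08 * 0.5    # "tool_relevance" term with a mocked relevance of 0.5
    invalid_tool_penalty = -0.25    # set only when the tool call raises (e.g. unknown tool)

    assert round(penalty + relevance_bonus, 2) == 0.09                          # valid call
    assert round(penalty + relevance_bonus + invalid_tool_penalty, 2) == -0.16  # failing call
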
src/osint_env/eval/runner.py CHANGED
@@ -5,14 +5,20 @@ from osint_env.agents.swarm_agent import SwarmAgentRunner
  from osint_env.env.environment import OSINTEnvironment
  from osint_env.env.reward import compute_graph_f1
  from osint_env.eval.metrics import EvalMetrics
+ from osint_env.llm.interface import LLMClient


- def run_evaluation(env: OSINTEnvironment, episodes: int = 20, return_details: bool = False) -> dict:
+ def run_evaluation(
+     env: OSINTEnvironment,
+     episodes: int = 20,
+     return_details: bool = False,
+     llm: LLMClient | None = None,
+ ) -> dict:
      metrics = EvalMetrics()
      if env.config.swarm.enabled:
-         runner = SwarmAgentRunner(env=env)
+         runner = SwarmAgentRunner(env=env, llm=llm)
      else:
-         runner = SingleAgentRunner(env=env)
+         runner = SingleAgentRunner(env=env, llm=llm)
      episode_rows: list[dict] = []
      for _ in range(episodes):
          info = runner.run_episode()
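
Note: a caller-side sketch of the updated signature, wiring one client into both the environment and the evaluation loop the way cli.py does above (mock provider, so no network or local model is needed):

    from osint_env.domain.models import EnvironmentConfig
    from osint_env.env.environment import OSINTEnvironment
    from osint_env.eval.runner import run_evaluation
    from osint_env.llm import build_llm_client

    cfg = EnvironmentConfig()                    # cfg.llm defaults to the mock provider
    llm_client = build_llm_client(cfg.llm)
    env = OSINTEnvironment(cfg, llm=llm_client)
    metrics = run_evaluation(env, episodes=2, llm=llm_client)
    print(metrics)
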
src/osint_env/llm/__init__.py CHANGED
@@ -1,2 +1,20 @@
  """LLM interface package."""

+ from osint_env.llm.interface import (
+     LLMClient,
+     LLMResponse,
+     OllamaLLMClient,
+     OpenAILLMClient,
+     RuleBasedMockLLM,
+     build_llm_client,
+ )
+
+ __all__ = [
+     "LLMClient",
+     "LLMResponse",
+     "RuleBasedMockLLM",
+     "OllamaLLMClient",
+     "OpenAILLMClient",
+     "build_llm_client",
+ ]
+
src/osint_env/llm/interface.py CHANGED
@@ -1,8 +1,15 @@
  from __future__ import annotations

+ import json
+ import os
  from dataclasses import dataclass
  from typing import Any, Protocol

+ import requests
+ from requests import RequestException
+
+ from osint_env.domain.models import LLMConfig
+

  @dataclass(slots=True)
  class LLMResponse:
@@ -30,3 +37,131 @@ class RuleBasedMockLLM:
              tool_calls=[{"tool_name": "search_posts", "args": {"query": "Update"}}, {"tool_name": "get_profile", "args": {"user_id": "user_0"}}],
          )
          return LLMResponse(content="Need profile lookup.", tool_calls=[{"tool_name": "search_people", "args": {"org": "Apex"}}])
+
+
+ class OllamaLLMClient:
+     def __init__(self, model: str, base_url: str = "http://127.0.0.1:11434", temperature: float = 0.1, timeout_seconds: int = 240):
+         self.model = model
+         self.base_url = base_url.rstrip("/")
+         self.temperature = float(temperature)
+         self.timeout_seconds = int(timeout_seconds)
+
+     @staticmethod
+     def _extract_tool_calls(content: str) -> list[dict[str, Any]]:
+         text = str(content or "").strip()
+         if not text:
+             return []
+         left = text.find("{")
+         right = text.rfind("}")
+         if left >= 0 and right > left:
+             snippet = text[left : right + 1]
+             try:
+                 parsed = json.loads(snippet)
+             except json.JSONDecodeError:
+                 parsed = None
+             if isinstance(parsed, dict) and isinstance(parsed.get("tool_calls"), list):
+                 out: list[dict[str, Any]] = []
+                 for item in parsed["tool_calls"]:
+                     if isinstance(item, dict) and "tool_name" in item and isinstance(item.get("args", {}), dict):
+                         out.append({"tool_name": str(item["tool_name"]), "args": dict(item.get("args", {}))})
+                 return out
+         return []
+
+     def generate(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]]) -> LLMResponse:
+         payload = {
+             "model": self.model,
+             "messages": messages,
+             "stream": False,
+             "options": {
+                 "temperature": self.temperature,
+             },
+         }
+         if tools:
+             payload["tools"] = tools
+         try:
+             response = requests.post(
+                 f"{self.base_url}/api/chat",
+                 json=payload,
+                 timeout=self.timeout_seconds,
+             )
+             response.raise_for_status()
+             data = response.json()
+             content = str((data.get("message") or {}).get("content", ""))
+             tool_calls = self._extract_tool_calls(content)
+             return LLMResponse(content=content, tool_calls=tool_calls)
+         except (RequestException, ValueError):
+             # Keep episode execution resilient when local model calls are transiently slow/unavailable.
+             return LLMResponse(content="", tool_calls=[])
+
+
+ class OpenAILLMClient:
+     def __init__(
+         self,
+         model: str,
+         api_key: str,
+         base_url: str = "https://api.openai.com/v1",
+         temperature: float = 0.1,
+         max_tokens: int = 256,
+         timeout_seconds: int = 240,
+     ):
+         from openai import OpenAI
+
+         self.model = model
+         self.temperature = float(temperature)
+         self.max_tokens = int(max_tokens)
+         self.client = OpenAI(api_key=api_key, base_url=base_url, timeout=timeout_seconds)
+
+     def generate(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]]) -> LLMResponse:
+         kwargs: dict[str, Any] = {
+             "model": self.model,
+             "messages": messages,
+             "temperature": self.temperature,
+             "max_tokens": self.max_tokens,
+         }
+         if tools:
+             kwargs["tools"] = tools
+         try:
+             completion = self.client.chat.completions.create(**kwargs)
+             message = completion.choices[0].message
+             content = message.content if isinstance(message.content, str) else ""
+
+             tool_calls: list[dict[str, Any]] = []
+             for tc in message.tool_calls or []:
+                 try:
+                     args = json.loads(tc.function.arguments or "{}")
+                 except json.JSONDecodeError:
+                     args = {}
+                 tool_calls.append({"tool_name": tc.function.name, "args": args if isinstance(args, dict) else {}})
+             return LLMResponse(content=content, tool_calls=tool_calls)
+         except Exception:
+             return LLMResponse(content="", tool_calls=[])
+
+
+ def build_llm_client(config: LLMConfig | None = None) -> LLMClient:
+     cfg = config or LLMConfig()
+     provider = str(cfg.provider).strip().lower()
+     if provider in {"", "mock", "rule", "rule_based"}:
+         return RuleBasedMockLLM()
+     if provider == "ollama":
+         return OllamaLLMClient(
+             model=cfg.model,
+             base_url=cfg.ollama_base_url,
+             temperature=cfg.temperature,
+             timeout_seconds=cfg.timeout_seconds,
+         )
+     if provider == "openai":
+         api_key = cfg.openai_api_key or os.getenv(cfg.openai_api_key_env, "")
+         if not api_key:
+             raise ValueError(
+                 "OpenAI provider selected but API key is missing. "
+                 f"Set {cfg.openai_api_key_env} or populate openai_api_key in config."
+             )
+         return OpenAILLMClient(
+             model=cfg.model,
+             api_key=api_key,
+             base_url=cfg.openai_base_url,
+             temperature=cfg.temperature,
+             max_tokens=cfg.max_tokens,
+             timeout_seconds=cfg.timeout_seconds,
+         )
+     raise ValueError(f"Unsupported llm provider: {cfg.provider}")
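
Note: smallest possible use of the new factory; the accepted provider strings are exactly those handled in build_llm_client above, and the question string is illustrative:

    from osint_env.domain.models import LLMConfig
    from osint_env.llm import build_llm_client

    client = build_llm_client(LLMConfig(provider="mock"))   # deterministic, fully offline
    resp = client.generate([{"role": "system", "content": "question: who owns alias_seed_001?"}], tools=[])
    print(resp.content, resp.tool_calls)                    # the mock returns a canned tool plan
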
tests/test_config.py CHANGED
@@ -23,9 +23,14 @@ def test_shared_config_parses_swarm_and_seeding(tmp_path: Path):
                      "question": "Which canonical user owns alias alias_seed_001?",
                      "answer": "user_seed_001",
                  }
-             ]
+             ],
+             "llm_generation_parallel": True,
+             "llm_generation_workers": 4,
+             "llm_generation_retries": 3,
+             "allow_template_fallback_on_llm_failure": False
          },
          "runtime": {"default_episodes": 5},
+         "llm": {"provider": "ollama", "model": "qwen3:2b", "timeout_seconds": 333},
      }
  ),
  encoding="utf-8",
@@ -37,6 +42,13 @@
      assert config.environment.swarm.max_width == 2
      assert len(config.environment.seeding.seeded_questions) == 1
      assert config.runtime.default_episodes == 5
+     assert config.environment.llm.provider == "ollama"
+     assert config.environment.llm.model == "qwen3:2b"
+     assert config.environment.llm.timeout_seconds == 333
+     assert config.environment.seeding.llm_generation_parallel is True
+     assert config.environment.seeding.llm_generation_workers == 4
+     assert config.environment.seeding.llm_generation_retries == 3
+     assert config.environment.seeding.allow_template_fallback_on_llm_failure is False


  def test_load_seeding_config_supports_top_level_object(tmp_path: Path):
tests/test_environment.py CHANGED
@@ -13,3 +13,25 @@ def test_episode_flow():
      assert done is True
      assert "total_reward" in info
      assert isinstance(r2, float)
+
+
+ def test_search_memory_tool_returns_results_after_tool_use():
+     env = OSINTEnvironment(EnvironmentConfig(max_steps=6, seed=5))
+     env.reset()
+     env.step(Action(ActionType.CALL_TOOL, {"tool_name": "search_posts", "args": {"query": "Update"}}))
+     obs, reward, done, _ = env.step(
+         Action(ActionType.CALL_TOOL, {"tool_name": "search_memory", "args": {"query": "Update", "k": 3}})
+     )
+     assert done is False
+     assert isinstance(reward, float)
+     assert obs.tool_outputs[-1]["tool"] == "search_memory"
+     assert obs.tool_outputs[-1]["output"]["count"] >= 1
+
+
+ def test_invalid_tool_call_does_not_crash_episode():
+     env = OSINTEnvironment(EnvironmentConfig(max_steps=4, seed=8))
+     env.reset()
+     _, reward, done, info = env.step(Action(ActionType.CALL_TOOL, {"tool_name": "no_such_tool", "args": {}}))
+     assert done is False
+     assert reward < 0
+     assert "invalid_tool_penalty" in info["reward_components"]
tests/test_generator.py CHANGED
@@ -1,5 +1,63 @@
+ import json
+ import re
+ from threading import Lock
+
  from osint_env.data.generator import DatasetGenerator
  from osint_env.domain.models import EnvironmentConfig
+ from osint_env.llm.interface import LLMResponse
+
+
+ class SharedContextLLM:
+     def __init__(self):
+         self.prompts: list[str] = []
+         self._lock = Lock()
+
+     def generate(self, messages, tools):
+         prompt = str(messages[0].get("content", "")) if messages else ""
+         with self._lock:
+             self.prompts.append(prompt)
+
+         if "SEED_GRAPH_EXPANSION_AGENT" in prompt:
+             worker_match = re.search(r"worker_id:\s*(\d+)", prompt)
+             worker_idx = int(worker_match.group(1)) if worker_match else 0
+             payload = {
+                 "edges": [
+                     {
+                         "src": "user_0",
+                         "rel": f"llm_rel_{worker_idx}",
+                         "dst": "user_1",
+                         "confidence": 0.9,
+                     }
+                 ]
+             }
+             return LLMResponse(content=json.dumps(payload), tool_calls=[])
+
+         if "SEED_TASK_EXPANSION_AGENT" in prompt:
+             worker_match = re.search(r"worker_id:\s*(\d+)", prompt)
+             worker_idx = int(worker_match.group(1)) if worker_match else 0
+             budget_match = re.search(r"task_budget:\s*(\d+)", prompt)
+             task_budget = int(budget_match.group(1)) if budget_match else 1
+             tasks = []
+             for local_idx in range(max(1, task_budget)):
+                 tasks.append(
+                     {
+                         "task_type": "identity_resolution",
+                         "question": f"Which canonical user is tied to alias alias_seed_{worker_idx}_{local_idx}?",
+                         "answer": "user_1",
+                         "supporting_edges": [
+                             {
+                                 "src": "alias_seed_0",
+                                 "rel": "alias_of",
+                                 "dst": "user_1",
+                                 "confidence": 0.95,
+                             }
+                         ],
+                     }
+                 )
+             payload = {"tasks": tasks}
+             return LLMResponse(content=json.dumps(payload), tool_calls=[])
+
+         return LLMResponse(content="{}", tool_calls=[])


  def test_generator_outputs():
@@ -10,3 +68,46 @@ def test_generator_outputs():
      assert len(graph.nodes) >= 20
      assert len(views.microblog_posts) == 20
      assert len(tasks) == 5
+
+
+ def test_graph_generation_uses_parallel_shared_context_workers():
+     cfg = EnvironmentConfig(n_users=12, seed=9)
+     cfg.seeding.llm_generate_remaining_graph = True
+     cfg.seeding.llm_generated_edge_budget = 4
+     cfg.seeding.llm_generate_remaining_tasks = False
+     cfg.seeding.llm_generation_parallel = True
+     cfg.seeding.llm_generation_workers = 3
+     cfg.seeding.llm_generation_retries = 1
+     cfg.seeding.allow_template_fallback_on_llm_failure = False
+
+     llm = SharedContextLLM()
+     gen = DatasetGenerator(cfg, llm=llm)
+     graph = gen.build_canonical_graph()
+
+     assert any(edge.rel.startswith("llm_rel_") for edge in graph.edges)
+     graph_prompts = [prompt for prompt in llm.prompts if "SEED_GRAPH_EXPANSION_AGENT" in prompt]
+     assert len(graph_prompts) >= 2
+     assert all("SHARED_CONTEXT" in prompt for prompt in graph_prompts)
+
+
+ def test_task_generation_uses_parallel_shared_context_workers():
+     cfg = EnvironmentConfig(n_users=12, seed=13)
+     cfg.seeding.llm_generate_remaining_graph = False
+     cfg.seeding.llm_generate_remaining_tasks = True
+     cfg.seeding.llm_generated_task_budget = 4
+     cfg.seeding.llm_generation_parallel = True
+     cfg.seeding.llm_generation_workers = 3
+     cfg.seeding.llm_generation_retries = 1
+     cfg.seeding.allow_template_fallback_on_llm_failure = False
+
+     llm = SharedContextLLM()
+     gen = DatasetGenerator(cfg, llm=llm)
+     graph = gen.build_canonical_graph()
+     views = gen.build_platform_views(graph)
+     tasks = gen.generate_tasks(graph, views, count=4)
+
+     assert len(tasks) == 4
+     assert any(task.metadata.get("shared_context") for task in tasks)
+     task_prompts = [prompt for prompt in llm.prompts if "SEED_TASK_EXPANSION_AGENT" in prompt]
+     assert len(task_prompts) >= 2
+     assert all("SHARED_CONTEXT" in prompt for prompt in task_prompts)
tests/test_llm_interface.py ADDED
@@ -0,0 +1,44 @@
+ import os
+
+ import pytest
+ import requests
+
+ from osint_env.domain.models import LLMConfig
+ from osint_env.llm.interface import OllamaLLMClient, RuleBasedMockLLM, build_llm_client
+
+
+ def test_build_llm_client_mock_default():
+     client = build_llm_client(LLMConfig(provider="mock"))
+     assert isinstance(client, RuleBasedMockLLM)
+
+
+ def test_build_llm_client_openai_requires_key(monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.delenv("OPENAI_API_KEY", raising=False)
+     with pytest.raises(ValueError):
+         build_llm_client(LLMConfig(provider="openai", openai_api_key="", openai_api_key_env="OPENAI_API_KEY"))
+
+
+ def test_build_llm_client_openai_with_key(monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.setenv("OPENAI_API_KEY", "test-key")
+     cfg = LLMConfig(provider="openai", model="gpt-4o-mini", openai_api_key_env="OPENAI_API_KEY")
+     # Constructing should not fail when a key is present; actual API call is not made in this test.
+     client = build_llm_client(cfg)
+     assert client is not None
+
+
+ def test_openai_key_can_come_from_config_value(monkeypatch: pytest.MonkeyPatch):
+     monkeypatch.delenv("OPENAI_API_KEY", raising=False)
+     cfg = LLMConfig(provider="openai", model="gpt-4o-mini", openai_api_key="cfg-key")
+     client = build_llm_client(cfg)
+     assert client is not None
+
+
+ def test_ollama_client_gracefully_handles_request_failure(monkeypatch: pytest.MonkeyPatch):
+     def _raise(*args, **kwargs):
+         raise requests.exceptions.ReadTimeout("timed out")
+
+     monkeypatch.setattr("osint_env.llm.interface.requests.post", _raise)
+     client = OllamaLLMClient(model="qwen3:2b", timeout_seconds=1)
+     response = client.generate([{"role": "system", "content": "ping"}], tools=[])
+     assert response.content == ""
+     assert response.tool_calls == []