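"""Tests for the osint_env dataset generator: the template pipeline, seeded
platform views, parallel shared-context LLM expansion, and the swarm-v2
path-tracing helpers."""
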
import json
import re
from threading import Lock
from osint_env.data.generator import (
    DatasetGenerator,
    build_swarm_v2_canonical_subgraph,
    build_swarm_v2_path_candidates,
    build_swarm_v2_tool_trace,
    emit_swarm_v2_question,
    select_swarm_v2_answer,
    trace_swarm_v2_path,
)
from osint_env.domain.models import EnvironmentConfig
from osint_env.llm.interface import LLMResponse
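

# Deterministic stand-in for the generator's LLM client. It logs every prompt
# it receives (behind a lock, since prompts may arrive from parallel workers)
# and answers with canned JSON payloads keyed off the agent marker embedded
# in each prompt.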
class SharedContextLLM:
    def __init__(self):
        self.prompts: list[str] = []
        self._lock = Lock()

    def generate(self, messages, tools):
        prompt = str(messages[0].get("content", "")) if messages else ""
        # Generation may fan out across worker threads, so guard the shared log.
        with self._lock:
            self.prompts.append(prompt)
        if "SEED_GRAPH_EXPANSION_AGENT" in prompt:
            # Echo the worker id back in the relation name so the tests can
            # tell which worker produced which edge.
            worker_match = re.search(r"worker_id:\s*(\d+)", prompt)
            worker_idx = int(worker_match.group(1)) if worker_match else 0
            payload = {
                "edges": [
                    {
                        "src": "user_0",
                        "rel": f"llm_rel_{worker_idx}",
                        "dst": "user_1",
                        "confidence": 0.9,
                    }
                ]
            }
            return LLMResponse(content=json.dumps(payload), tool_calls=[])
        if "SEED_TASK_EXPANSION_AGENT" in prompt:
            worker_match = re.search(r"worker_id:\s*(\d+)", prompt)
            worker_idx = int(worker_match.group(1)) if worker_match else 0
            budget_match = re.search(r"task_budget:\s*(\d+)", prompt)
            task_budget = int(budget_match.group(1)) if budget_match else 1
            # Emit exactly as many tasks as the prompt budgets for (at least one).
            tasks = []
            for local_idx in range(max(1, task_budget)):
                tasks.append(
                    {
                        "task_type": "identity_resolution",
                        "question": f"Which canonical user is tied to alias alias_seed_{worker_idx}_{local_idx}?",
                        "answer": "user_1",
                        "supporting_edges": [
                            {
                                "src": "alias_seed_0",
                                "rel": "alias_of",
                                "dst": "user_1",
                                "confidence": 0.95,
                            }
                        ],
                    }
                )
            payload = {"tasks": tasks}
            return LLMResponse(content=json.dumps(payload), tool_calls=[])
        # Any other prompt gets an empty JSON object.
        return LLMResponse(content="{}", tool_calls=[])
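

# Smoke test: a default, template-driven run should yield a populated graph,
# platform views, and exactly the requested number of tasks.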
def test_generator_outputs():
    gen = DatasetGenerator(EnvironmentConfig(n_users=20, seed=11))
    graph = gen.build_canonical_graph()
    views = gen.build_platform_views(graph)
    tasks = gen.generate_tasks(graph, views, count=5)
    assert len(graph.nodes) >= 20
    assert len(views.microblog_posts) >= 20
    assert len(tasks) == 5
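

# The fixed-levels seeding config should inject its hand-authored post and
# forum thread, with their entity references intact, into the platform views.
# (Assumes the repo's datasets/fixed_levels/ fixtures are present.)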
def test_seeded_views_include_seeded_posts_and_threads():
    from osint_env.config import clone_environment_config, load_seeding_config, load_shared_config

    shared = load_shared_config("datasets/fixed_levels/shared_config_fixed_levels.json")
    cfg = clone_environment_config(shared.environment)
    cfg.seeding = load_seeding_config("datasets/fixed_levels/seed_fixed_levels.json")
    cfg.llm.provider = "mock"
    gen = DatasetGenerator(cfg)
    graph = gen.build_canonical_graph()
    views = gen.build_platform_views(graph)
    seeded_post = next((post for post in views.microblog_posts if post["post_id"] == "post_midnight_manifest"), None)
    seeded_thread = next((thread for thread in views.forum_threads if thread["thread_id"] == "thr_supply_leak"), None)
    assert seeded_post is not None
    assert "loc_dockyard17" in seeded_post["references"]
    assert seeded_thread is not None
    assert "org_northbridge_logistics" in seeded_thread["references"]
def test_graph_generation_uses_parallel_shared_context_workers():
    cfg = EnvironmentConfig(n_users=12, seed=9)
    cfg.seeding.llm_generate_remaining_graph = True
    cfg.seeding.llm_generated_edge_budget = 4
    cfg.seeding.llm_generate_remaining_tasks = False
    cfg.seeding.llm_generation_parallel = True
    cfg.seeding.llm_generation_workers = 3
    cfg.seeding.llm_generation_retries = 1
    cfg.seeding.allow_template_fallback_on_llm_failure = False
    llm = SharedContextLLM()
    gen = DatasetGenerator(cfg, llm=llm)
    graph = gen.build_canonical_graph()
    assert any(edge.rel.startswith("llm_rel_") for edge in graph.edges)
    graph_prompts = [prompt for prompt in llm.prompts if "SEED_GRAPH_EXPANSION_AGENT" in prompt]
    assert len(graph_prompts) >= 2
    assert all("SHARED_CONTEXT" in prompt for prompt in graph_prompts)
def test_task_generation_uses_parallel_shared_context_workers():
    cfg = EnvironmentConfig(n_users=12, seed=13)
    cfg.seeding.llm_generate_remaining_graph = False
    cfg.seeding.llm_generate_remaining_tasks = True
    cfg.seeding.llm_generated_task_budget = 4
    cfg.seeding.llm_generation_parallel = True
    cfg.seeding.llm_generation_workers = 3
    cfg.seeding.llm_generation_retries = 1
    cfg.seeding.allow_template_fallback_on_llm_failure = False
    llm = SharedContextLLM()
    gen = DatasetGenerator(cfg, llm=llm)
    graph = gen.build_canonical_graph()
    views = gen.build_platform_views(graph)
    tasks = gen.generate_tasks(graph, views, count=4)
    assert len(tasks) == 4
    assert any(task.metadata.get("shared_context") for task in tasks)
    task_prompts = [prompt for prompt in llm.prompts if "SEED_TASK_EXPANSION_AGENT" in prompt]
    assert len(task_prompts) >= 2
    assert all("SHARED_CONTEXT" in prompt for prompt in task_prompts)
def test_swarm_v2_path_tools_replay_a_valid_multi_hop_trace():
    gen = DatasetGenerator(EnvironmentConfig(n_users=20, seed=17))
    graph = gen.build_canonical_graph()
    candidates = build_swarm_v2_path_candidates(graph, gen.rng, count=4, min_hops=2, max_hops=3)
    assert candidates
    traced = trace_swarm_v2_path(graph, candidates[0])
    assert traced
    assert len(traced) >= 2
    question = emit_swarm_v2_question(traced)
    answer = select_swarm_v2_answer(traced)
    tool_trace = build_swarm_v2_tool_trace(graph, traced)
    canonical = build_swarm_v2_canonical_subgraph(graph, traced, max_extra_edges=2)
    assert question.startswith("If you start at")
    assert answer == traced[-1].dst
    assert any(call["tool_name"] == "trace_path" for call in tool_trace)
    assert canonical["path"]
    assert canonical["answer"] == answer