"""Tests for DatasetGenerator seeding and parallel shared-context LLM generation."""
import json
import re
from threading import Lock
from osint_env.data.generator import DatasetGenerator
from osint_env.domain.models import EnvironmentConfig
from osint_env.llm.interface import LLMResponse
class SharedContextLLM:
    """Deterministic mock LLM for seed-expansion tests.

    Records every prompt it is handed (thread-safely, since the generator may
    fan prompts out across worker threads) and answers the two seed-expansion
    agents with fixed JSON payloads derived from ids embedded in the prompt.
    """

    def __init__(self):
        # Every prompt seen so far, in arrival order; append is lock-guarded.
        self.prompts: list[str] = []
        self._lock = Lock()

    def generate(self, messages, tools):
        prompt = str(messages[0].get("content", "")) if messages else ""
        with self._lock:
            self.prompts.append(prompt)

        def _prompt_int(pattern: str, default: int) -> int:
            # Pull an integer field (e.g. "worker_id: 2") out of the prompt.
            found = re.search(pattern, prompt)
            return int(found.group(1)) if found else default

        if "SEED_GRAPH_EXPANSION_AGENT" in prompt:
            worker = _prompt_int(r"worker_id:\s*(\d+)", 0)
            edge = {
                "src": "user_0",
                "rel": f"llm_rel_{worker}",
                "dst": "user_1",
                "confidence": 0.9,
            }
            return LLMResponse(content=json.dumps({"edges": [edge]}), tool_calls=[])

        if "SEED_TASK_EXPANSION_AGENT" in prompt:
            worker = _prompt_int(r"worker_id:\s*(\d+)", 0)
            budget = _prompt_int(r"task_budget:\s*(\d+)", 1)
            # NOTE(review): the supporting edge src is fixed at "alias_seed_0"
            # while the question alias varies per worker/task -- confirm intended.
            tasks = [
                {
                    "task_type": "identity_resolution",
                    "question": f"Which canonical user is tied to alias alias_seed_{worker}_{item}?",
                    "answer": "user_1",
                    "supporting_edges": [
                        {
                            "src": "alias_seed_0",
                            "rel": "alias_of",
                            "dst": "user_1",
                            "confidence": 0.95,
                        }
                    ],
                }
                for item in range(max(1, budget))
            ]
            return LLMResponse(content=json.dumps({"tasks": tasks}), tool_calls=[])

        # Any other prompt gets an empty JSON object.
        return LLMResponse(content="{}", tool_calls=[])
def test_generator_outputs():
    """Smoke test: graph, platform views, and tasks all materialize at expected sizes."""
    generator = DatasetGenerator(EnvironmentConfig(n_users=20, seed=11))
    canonical = generator.build_canonical_graph()
    platform_views = generator.build_platform_views(canonical)
    generated_tasks = generator.generate_tasks(canonical, platform_views, count=5)

    assert len(canonical.nodes) >= 20
    assert len(platform_views.microblog_posts) >= 20
    assert len(generated_tasks) == 5
def test_seeded_views_include_seeded_posts_and_threads():
    """Seeded fixture posts/threads must survive platform-view construction."""
    from osint_env.config import clone_environment_config, load_seeding_config, load_shared_config

    shared_cfg = load_shared_config("datasets/fixed_levels/shared_config_fixed_levels.json")
    env_cfg = clone_environment_config(shared_cfg.environment)
    env_cfg.seeding = load_seeding_config("datasets/fixed_levels/seed_fixed_levels.json")
    env_cfg.llm.provider = "mock"

    generator = DatasetGenerator(env_cfg)
    canonical = generator.build_canonical_graph()
    platform_views = generator.build_platform_views(canonical)

    post = None
    for candidate in platform_views.microblog_posts:
        if candidate["post_id"] == "post_midnight_manifest":
            post = candidate
            break
    thread = None
    for candidate in platform_views.forum_threads:
        if candidate["thread_id"] == "thr_supply_leak":
            thread = candidate
            break

    assert post is not None
    assert "loc_dockyard17" in post["references"]
    assert thread is not None
    assert "org_northbridge_logistics" in thread["references"]
def test_graph_generation_uses_parallel_shared_context_workers():
    """Parallel graph expansion should fan prompts out across shared-context workers."""
    config = EnvironmentConfig(n_users=12, seed=9)
    seeding = config.seeding
    seeding.llm_generate_remaining_graph = True
    seeding.llm_generated_edge_budget = 4
    seeding.llm_generate_remaining_tasks = False
    seeding.llm_generation_parallel = True
    seeding.llm_generation_workers = 3
    seeding.llm_generation_retries = 1
    seeding.allow_template_fallback_on_llm_failure = False

    mock_llm = SharedContextLLM()
    canonical = DatasetGenerator(config, llm=mock_llm).build_canonical_graph()

    # Mock-generated edges carry the worker-tagged relation name.
    assert any(edge.rel.startswith("llm_rel_") for edge in canonical.edges)
    graph_prompts = [p for p in mock_llm.prompts if "SEED_GRAPH_EXPANSION_AGENT" in p]
    assert len(graph_prompts) >= 2
    assert all("SHARED_CONTEXT" in p for p in graph_prompts)
def test_task_generation_uses_parallel_shared_context_workers():
    """Parallel task expansion should fan prompts out across shared-context workers."""
    config = EnvironmentConfig(n_users=12, seed=13)
    seeding = config.seeding
    seeding.llm_generate_remaining_graph = False
    seeding.llm_generate_remaining_tasks = True
    seeding.llm_generated_task_budget = 4
    seeding.llm_generation_parallel = True
    seeding.llm_generation_workers = 3
    seeding.llm_generation_retries = 1
    seeding.allow_template_fallback_on_llm_failure = False

    mock_llm = SharedContextLLM()
    generator = DatasetGenerator(config, llm=mock_llm)
    canonical = generator.build_canonical_graph()
    platform_views = generator.build_platform_views(canonical)
    generated_tasks = generator.generate_tasks(canonical, platform_views, count=4)

    assert len(generated_tasks) == 4
    assert any(task.metadata.get("shared_context") for task in generated_tasks)
    task_prompts = [p for p in mock_llm.prompts if "SEED_TASK_EXPANSION_AGENT" in p]
    assert len(task_prompts) >= 2
    assert all("SHARED_CONTEXT" in p for p in task_prompts)
|