OSINT / tests /test_generator.py
siddeshwar-kagatikar
Fix seeded context retrieval and latest Space dashboard
9af411f
import json
import re
from threading import Lock
from osint_env.data.generator import DatasetGenerator
from osint_env.domain.models import EnvironmentConfig
from osint_env.llm.interface import LLMResponse
class SharedContextLLM:
    """Deterministic fake LLM used by the seeded-expansion tests.

    Every prompt handed to :meth:`generate` is recorded in ``prompts`` so the
    tests can inspect what the generator sent. The reply is a canned JSON
    document chosen by the agent marker found in the prompt:

    * ``SEED_GRAPH_EXPANSION_AGENT`` -> one ``edges`` entry tagged with the worker id.
    * ``SEED_TASK_EXPANSION_AGENT``  -> ``task_budget`` (at least one) identity tasks.
    * anything else                  -> ``"{}"``.
    """

    # Integer fields the generator embeds in its prompts (e.g. "worker_id: 2").
    _WORKER_PATTERN = re.compile(r"worker_id:\s*(\d+)")
    _BUDGET_PATTERN = re.compile(r"task_budget:\s*(\d+)")

    def __init__(self):
        # Prompts in arrival order. Appends are guarded by ``_lock`` because the
        # tests enable parallel generation workers that may share this instance.
        self.prompts: list[str] = []
        self._lock = Lock()

    @staticmethod
    def _read_int(pattern, text, fallback):
        """Return the first integer captured by *pattern* in *text*, or *fallback*."""
        found = pattern.search(text)
        return fallback if found is None else int(found.group(1))

    def _graph_payload(self, prompt):
        """Build a single ``user_0 -> user_1`` edge whose relation names the worker."""
        worker = self._read_int(self._WORKER_PATTERN, prompt, 0)
        edge = {
            "src": "user_0",
            "rel": f"llm_rel_{worker}",
            "dst": "user_1",
            "confidence": 0.9,
        }
        return {"edges": [edge]}

    def _task_payload(self, prompt):
        """Build ``max(1, task_budget)`` identity-resolution tasks for this worker."""
        worker = self._read_int(self._WORKER_PATTERN, prompt, 0)
        budget = self._read_int(self._BUDGET_PATTERN, prompt, 1)
        return {
            "tasks": [
                {
                    "task_type": "identity_resolution",
                    "question": f"Which canonical user is tied to alias alias_seed_{worker}_{slot}?",
                    "answer": "user_1",
                    # NOTE(review): the supporting edge always uses "alias_seed_0",
                    # not the per-task alias named in the question — confirm this
                    # is intentional for the generator's validation.
                    "supporting_edges": [
                        {
                            "src": "alias_seed_0",
                            "rel": "alias_of",
                            "dst": "user_1",
                            "confidence": 0.95,
                        }
                    ],
                }
                for slot in range(max(1, budget))
            ]
        }

    def generate(self, messages, tools):
        """Record the prompt and return a canned ``LLMResponse`` for it.

        Only the ``content`` of the first message is inspected; ``tools`` is
        accepted for interface compatibility and otherwise ignored.
        """
        if messages:
            prompt = str(messages[0].get("content", ""))
        else:
            prompt = ""
        with self._lock:
            self.prompts.append(prompt)

        # Graph marker is checked first, matching the original precedence.
        if "SEED_GRAPH_EXPANSION_AGENT" in prompt:
            body = self._graph_payload(prompt)
        elif "SEED_TASK_EXPANSION_AGENT" in prompt:
            body = self._task_payload(prompt)
        else:
            return LLMResponse(content="{}", tool_calls=[])
        return LLMResponse(content=json.dumps(body), tool_calls=[])
def test_generator_outputs():
    """A small seeded config yields a populated graph, platform views, and exactly the requested tasks."""
    user_count = 20
    generator = DatasetGenerator(EnvironmentConfig(n_users=user_count, seed=11))

    canonical = generator.build_canonical_graph()
    platform_views = generator.build_platform_views(canonical)
    generated = generator.generate_tasks(canonical, platform_views, count=5)

    assert len(canonical.nodes) >= user_count
    assert len(platform_views.microblog_posts) >= user_count
    assert len(generated) == 5
def test_seeded_views_include_seeded_posts_and_threads():
    """Seeded posts and forum threads from the fixed-levels seed appear in the platform views.

    Loads the fixed-levels shared config and seed file, forces the ``mock`` LLM
    provider, and checks that the seeded post ``post_midnight_manifest`` and
    thread ``thr_supply_leak`` are present with their expected references.
    """
    from pathlib import Path

    from osint_env.config import clone_environment_config, load_seeding_config, load_shared_config

    def resolve(relative: str) -> str:
        # Anchor dataset paths at the repo root (the parent of tests/) so the
        # test does not depend on the directory pytest was launched from.
        # Fall back to the original cwd-relative path if that file is absent.
        anchored = Path(__file__).resolve().parents[1] / relative
        return str(anchored) if anchored.exists() else relative

    shared = load_shared_config(resolve("datasets/fixed_levels/shared_config_fixed_levels.json"))
    cfg = clone_environment_config(shared.environment)
    cfg.seeding = load_seeding_config(resolve("datasets/fixed_levels/seed_fixed_levels.json"))
    cfg.llm.provider = "mock"

    gen = DatasetGenerator(cfg)
    graph = gen.build_canonical_graph()
    views = gen.build_platform_views(graph)

    seeded_post = next((post for post in views.microblog_posts if post["post_id"] == "post_midnight_manifest"), None)
    seeded_thread = next((thread for thread in views.forum_threads if thread["thread_id"] == "thr_supply_leak"), None)

    assert seeded_post is not None, "seeded post post_midnight_manifest missing from microblog view"
    assert "loc_dockyard17" in seeded_post["references"]
    assert seeded_thread is not None, "seeded thread thr_supply_leak missing from forum view"
    assert "org_northbridge_logistics" in seeded_thread["references"]
def test_graph_generation_uses_parallel_shared_context_workers():
    """Parallel graph-expansion prompts all carry SHARED_CONTEXT and LLM edges reach the graph."""
    cfg = EnvironmentConfig(n_users=12, seed=9)
    seeding_overrides = {
        "llm_generate_remaining_graph": True,
        "llm_generated_edge_budget": 4,
        "llm_generate_remaining_tasks": False,
        "llm_generation_parallel": True,
        "llm_generation_workers": 3,
        "llm_generation_retries": 1,
        "allow_template_fallback_on_llm_failure": False,
    }
    for option, value in seeding_overrides.items():
        setattr(cfg.seeding, option, value)

    fake_llm = SharedContextLLM()
    graph = DatasetGenerator(cfg, llm=fake_llm).build_canonical_graph()

    assert any(edge.rel.startswith("llm_rel_") for edge in graph.edges)
    expansion_prompts = [text for text in fake_llm.prompts if "SEED_GRAPH_EXPANSION_AGENT" in text]
    # With three workers configured, more than one expansion prompt is expected.
    assert len(expansion_prompts) >= 2
    assert all("SHARED_CONTEXT" in text for text in expansion_prompts)
def test_task_generation_uses_parallel_shared_context_workers():
    """Parallel task-expansion prompts all carry SHARED_CONTEXT and yield exactly the requested tasks."""
    cfg = EnvironmentConfig(n_users=12, seed=13)
    seeding_overrides = {
        "llm_generate_remaining_graph": False,
        "llm_generate_remaining_tasks": True,
        "llm_generated_task_budget": 4,
        "llm_generation_parallel": True,
        "llm_generation_workers": 3,
        "llm_generation_retries": 1,
        "allow_template_fallback_on_llm_failure": False,
    }
    for option, value in seeding_overrides.items():
        setattr(cfg.seeding, option, value)

    fake_llm = SharedContextLLM()
    generator = DatasetGenerator(cfg, llm=fake_llm)
    graph = generator.build_canonical_graph()
    views = generator.build_platform_views(graph)
    generated = generator.generate_tasks(graph, views, count=4)

    assert len(generated) == 4
    assert any(item.metadata.get("shared_context") for item in generated)
    expansion_prompts = [text for text in fake_llm.prompts if "SEED_TASK_EXPANSION_AGENT" in text]
    # With three workers configured, more than one expansion prompt is expected.
    assert len(expansion_prompts) >= 2
    assert all("SHARED_CONTEXT" in text for text in expansion_prompts)