import json
import re
from threading import Lock

from osint_env.data.generator import DatasetGenerator
from osint_env.domain.models import EnvironmentConfig
from osint_env.llm.interface import LLMResponse
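# Deterministic mock LLM used by the parallel-worker tests below: it records every
# prompt under a lock and answers graph- or task-expansion prompts with fixed JSON
# payloads keyed on the agent marker and worker_id found in the prompt.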
class SharedContextLLM:
    def __init__(self):
        self.prompts: list[str] = []
        self._lock = Lock()

    def generate(self, messages, tools):
        prompt = str(messages[0].get("content", "")) if messages else ""
        with self._lock:
            self.prompts.append(prompt)
        if "SEED_GRAPH_EXPANSION_AGENT" in prompt:
            worker_match = re.search(r"worker_id:\s*(\d+)", prompt)
            worker_idx = int(worker_match.group(1)) if worker_match else 0
            payload = {
                "edges": [
                    {
                        "src": "user_0",
                        "rel": f"llm_rel_{worker_idx}",
                        "dst": "user_1",
                        "confidence": 0.9,
                    }
                ]
            }
            return LLMResponse(content=json.dumps(payload), tool_calls=[])
        if "SEED_TASK_EXPANSION_AGENT" in prompt:
            worker_match = re.search(r"worker_id:\s*(\d+)", prompt)
            worker_idx = int(worker_match.group(1)) if worker_match else 0
            budget_match = re.search(r"task_budget:\s*(\d+)", prompt)
            task_budget = int(budget_match.group(1)) if budget_match else 1
            tasks = []
            for local_idx in range(max(1, task_budget)):
                tasks.append(
                    {
                        "task_type": "identity_resolution",
                        "question": f"Which canonical user is tied to alias alias_seed_{worker_idx}_{local_idx}?",
                        "answer": "user_1",
                        "supporting_edges": [
                            {
                                "src": "alias_seed_0",
                                "rel": "alias_of",
                                "dst": "user_1",
                                "confidence": 0.95,
                            }
                        ],
                    }
                )
            payload = {"tasks": tasks}
            return LLMResponse(content=json.dumps(payload), tool_calls=[])
        return LLMResponse(content="{}", tool_calls=[])
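# Smoke test: with only a default configuration, the generator produces a canonical
# graph, platform views, and exactly the requested number of tasks.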
def test_generator_outputs():
    gen = DatasetGenerator(EnvironmentConfig(n_users=20, seed=11))
    graph = gen.build_canonical_graph()
    views = gen.build_platform_views(graph)
    tasks = gen.generate_tasks(graph, views, count=5)
    assert len(graph.nodes) >= 20
    assert len(views.microblog_posts) >= 20
    assert len(tasks) == 5
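# Loads the fixed-levels shared config and seed file and checks that hand-seeded
# posts and forum threads survive into the generated platform views.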
def test_seeded_views_include_seeded_posts_and_threads():
    from osint_env.config import clone_environment_config, load_seeding_config, load_shared_config

    shared = load_shared_config("datasets/fixed_levels/shared_config_fixed_levels.json")
    cfg = clone_environment_config(shared.environment)
    cfg.seeding = load_seeding_config("datasets/fixed_levels/seed_fixed_levels.json")
    cfg.llm.provider = "mock"
    gen = DatasetGenerator(cfg)
    graph = gen.build_canonical_graph()
    views = gen.build_platform_views(graph)
    seeded_post = next((post for post in views.microblog_posts if post["post_id"] == "post_midnight_manifest"), None)
    seeded_thread = next((thread for thread in views.forum_threads if thread["thread_id"] == "thr_supply_leak"), None)
    assert seeded_post is not None
    assert "loc_dockyard17" in seeded_post["references"]
    assert seeded_thread is not None
    assert "org_northbridge_logistics" in seeded_thread["references"]
def test_graph_generation_uses_parallel_shared_context_workers():
    cfg = EnvironmentConfig(n_users=12, seed=9)
    cfg.seeding.llm_generate_remaining_graph = True
    cfg.seeding.llm_generated_edge_budget = 4
    cfg.seeding.llm_generate_remaining_tasks = False
    cfg.seeding.llm_generation_parallel = True
    cfg.seeding.llm_generation_workers = 3
    cfg.seeding.llm_generation_retries = 1
    cfg.seeding.allow_template_fallback_on_llm_failure = False
    llm = SharedContextLLM()
    gen = DatasetGenerator(cfg, llm=llm)
    graph = gen.build_canonical_graph()
    assert any(edge.rel.startswith("llm_rel_") for edge in graph.edges)
    graph_prompts = [prompt for prompt in llm.prompts if "SEED_GRAPH_EXPANSION_AGENT" in prompt]
    assert len(graph_prompts) >= 2
    assert all("SHARED_CONTEXT" in prompt for prompt in graph_prompts)
def test_task_generation_uses_parallel_shared_context_workers():
    cfg = EnvironmentConfig(n_users=12, seed=13)
    cfg.seeding.llm_generate_remaining_graph = False
    cfg.seeding.llm_generate_remaining_tasks = True
    cfg.seeding.llm_generated_task_budget = 4
    cfg.seeding.llm_generation_parallel = True
    cfg.seeding.llm_generation_workers = 3
    cfg.seeding.llm_generation_retries = 1
    cfg.seeding.allow_template_fallback_on_llm_failure = False
    llm = SharedContextLLM()
    gen = DatasetGenerator(cfg, llm=llm)
    graph = gen.build_canonical_graph()
    views = gen.build_platform_views(graph)
    tasks = gen.generate_tasks(graph, views, count=4)
    assert len(tasks) == 4
    assert any(task.metadata.get("shared_context") for task in tasks)
    task_prompts = [prompt for prompt in llm.prompts if "SEED_TASK_EXPANSION_AGENT" in prompt]
    assert len(task_prompts) >= 2
    assert all("SHARED_CONTEXT" in prompt for prompt in task_prompts)