import json
import re
from threading import Lock

from osint_env.data.generator import DatasetGenerator
from osint_env.domain.models import EnvironmentConfig
from osint_env.llm.interface import LLMResponse


class SharedContextLLM:
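    """Thread-safe LLM stub that records every prompt and returns deterministic JSON."""
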
    def __init__(self):
        self.prompts: list[str] = []
        self._lock = Lock()

    def generate(self, messages, tools):
        """Return a canned JSON payload keyed off the seed-agent marker in the prompt."""
        prompt = str(messages[0].get("content", "")) if messages else ""
        # Capture every prompt under a lock so parallel workers can be asserted on later.
        with self._lock:
            self.prompts.append(prompt)

        # Workers identify themselves with a "worker_id: N" marker in the prompt.
        worker_match = re.search(r"worker_id:\s*(\d+)", prompt)
        worker_idx = int(worker_match.group(1)) if worker_match else 0

        if "SEED_GRAPH_EXPANSION_AGENT" in prompt:
            payload = {
                "edges": [
                    {
                        "src": "user_0",
                        "rel": f"llm_rel_{worker_idx}",
                        "dst": "user_1",
                        "confidence": 0.9,
                    }
                ]
            }
            return LLMResponse(content=json.dumps(payload), tool_calls=[])

        if "SEED_TASK_EXPANSION_AGENT" in prompt:
            # Emit one task per unit of the announced task_budget (at least one).
            budget_match = re.search(r"task_budget:\s*(\d+)", prompt)
            task_budget = int(budget_match.group(1)) if budget_match else 1
            tasks = []
            for local_idx in range(max(1, task_budget)):
                tasks.append(
                    {
                        "task_type": "identity_resolution",
                        "question": f"Which canonical user is tied to alias alias_seed_{worker_idx}_{local_idx}?",
                        "answer": "user_1",
                        "supporting_edges": [
                            {
                                "src": "alias_seed_0",
                                "rel": "alias_of",
                                "dst": "user_1",
                                "confidence": 0.95,
                            }
                        ],
                    }
                )
            payload = {"tasks": tasks}
            return LLMResponse(content=json.dumps(payload), tool_calls=[])

        # Anything else gets an empty JSON object.
        return LLMResponse(content="{}", tool_calls=[])


def test_generator_outputs():
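    """Graph, platform views, and tasks should all come out at the requested sizes."""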
    gen = DatasetGenerator(EnvironmentConfig(n_users=20, seed=11))
    graph = gen.build_canonical_graph()
    views = gen.build_platform_views(graph)
    tasks = gen.generate_tasks(graph, views, count=5)
    assert len(graph.nodes) >= 20
    assert len(views.microblog_posts) >= 20
    assert len(tasks) == 5


def test_seeded_views_include_seeded_posts_and_threads():
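    """Seeded fixture posts and threads should appear verbatim in the platform views."""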
    from osint_env.config import clone_environment_config, load_seeding_config, load_shared_config

    shared = load_shared_config("datasets/fixed_levels/shared_config_fixed_levels.json")
    cfg = clone_environment_config(shared.environment)
    cfg.seeding = load_seeding_config("datasets/fixed_levels/seed_fixed_levels.json")
    cfg.llm.provider = "mock"

    gen = DatasetGenerator(cfg)
    graph = gen.build_canonical_graph()
    views = gen.build_platform_views(graph)

    seeded_post = next((post for post in views.microblog_posts if post["post_id"] == "post_midnight_manifest"), None)
    seeded_thread = next((thread for thread in views.forum_threads if thread["thread_id"] == "thr_supply_leak"), None)

    assert seeded_post is not None
    assert "loc_dockyard17" in seeded_post["references"]
    assert seeded_thread is not None
    assert "org_northbridge_logistics" in seeded_thread["references"]


def test_graph_generation_uses_parallel_shared_context_workers():
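    """Graph expansion should fan out across parallel workers that share one context."""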
    cfg = EnvironmentConfig(n_users=12, seed=9)
    cfg.seeding.llm_generate_remaining_graph = True
    cfg.seeding.llm_generated_edge_budget = 4
    cfg.seeding.llm_generate_remaining_tasks = False
    cfg.seeding.llm_generation_parallel = True
    cfg.seeding.llm_generation_workers = 3
    cfg.seeding.llm_generation_retries = 1
    cfg.seeding.allow_template_fallback_on_llm_failure = False

    llm = SharedContextLLM()
    gen = DatasetGenerator(cfg, llm=llm)
    graph = gen.build_canonical_graph()

    assert any(edge.rel.startswith("llm_rel_") for edge in graph.edges)
    graph_prompts = [prompt for prompt in llm.prompts if "SEED_GRAPH_EXPANSION_AGENT" in prompt]
    assert len(graph_prompts) >= 2
    assert all("SHARED_CONTEXT" in prompt for prompt in graph_prompts)


def test_task_generation_uses_parallel_shared_context_workers():
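    """Task expansion should fan out across parallel workers and fill the task budget."""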
    cfg = EnvironmentConfig(n_users=12, seed=13)
    cfg.seeding.llm_generate_remaining_graph = False
    cfg.seeding.llm_generate_remaining_tasks = True
    cfg.seeding.llm_generated_task_budget = 4
    cfg.seeding.llm_generation_parallel = True
    cfg.seeding.llm_generation_workers = 3
    cfg.seeding.llm_generation_retries = 1
    cfg.seeding.allow_template_fallback_on_llm_failure = False

    llm = SharedContextLLM()
    gen = DatasetGenerator(cfg, llm=llm)
    graph = gen.build_canonical_graph()
    views = gen.build_platform_views(graph)
    tasks = gen.generate_tasks(graph, views, count=4)

    assert len(tasks) == 4
    assert any(task.metadata.get("shared_context") for task in tasks)
    task_prompts = [prompt for prompt in llm.prompts if "SEED_TASK_EXPANSION_AGENT" in prompt]
    assert len(task_prompts) >= 2
    assert all("SHARED_CONTEXT" in prompt for prompt in task_prompts)