File size: 2,436 Bytes
d814291
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from pathlib import Path

from osint_env.data.generator import DatasetGenerator
from osint_env.domain.models import EnvironmentConfig


def _write_metaqa_fixture(root: Path) -> None:
    root.mkdir(parents=True, exist_ok=True)
    (root / "kb.txt").write_text(
        "\n".join(
            [
                "Movie A|starred_actors|Actor X",
                "Movie B|starred_actors|Actor X",
                "Movie A|directed_by|Director D",
                "Movie C|directed_by|Director D",
                "Movie C|release_year|2002",
            ]
        ),
        encoding="utf-8",
    )

    rows = {
        "1-hop": ("what movies did [Actor X] act in\tMovie A|Movie B\n", "actor_to_movie\n"),
        "2-hop": ("which films share the director of [Movie A]\tMovie C\n", "movie_to_director_to_movie\n"),
        "3-hop": (
            "which release year corresponds to films with same director as [Movie A]\t2002\n",
            "movie_to_director_to_movie_to_year\n",
        ),
    }

    for hop, (qa_line, qtype_line) in rows.items():
        qa_dir = root / hop / "vanilla"
        qa_dir.mkdir(parents=True, exist_ok=True)
        (qa_dir / "qa_train.txt").write_text(qa_line, encoding="utf-8")
        (root / hop / "qa_train_qtype.txt").write_text(qtype_line, encoding="utf-8")


def test_metaqa_mode_builds_graph_and_hop_tasks(tmp_path: Path):
    metaqa_root = tmp_path / "metaQA"
    _write_metaqa_fixture(metaqa_root)

    cfg = EnvironmentConfig(
        seed=5,
        dataset_mode="metaqa",
        metaqa_root=str(metaqa_root),
        metaqa_variant="vanilla",
        metaqa_hops=["1-hop", "2-hop", "3-hop"],
        metaqa_splits=["train"],
    )

    gen = DatasetGenerator(cfg)
    graph = gen.build_canonical_graph()
    views = gen.build_platform_views(graph)
    tasks = gen.generate_tasks(graph, views, count=24)

    assert len(graph.nodes) >= 5
    assert any(edge.rel == "directed_by" for edge in graph.edges)
    assert any(post["post_id"].startswith("post_metaqa_") for post in views.microblog_posts)
    assert any(profile["user_id"] == "Actor X" for profile in views.profiles)

    hop_labels = {str(task.metadata.get("hop", "")) for task in tasks}
    difficulties = {str(task.metadata.get("difficulty", "")) for task in tasks}

    assert hop_labels == {"1-hop", "2-hop", "3-hop"}
    assert difficulties == {"easy", "medium", "hard"}
    assert all(task.supporting_edges for task in tasks)