File size: 6,435 Bytes
7772fe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""
Temporal Order Experiment
Tests position bias when information has inherent chronological ordering:
Chronological, Reverse-chronological, Scrambled.
"""
import logging
import os
import random
import re
import time
from typing import List, Dict, Any

from tqdm import tqdm

from src.generator import generate_text
from src.utils import ensure_dir, save_jsonl, save_json

logger = logging.getLogger(__name__)

# Pool of filler events used to pad the generated timelines around the target
# event. The timeline builders in this module draw from it with random.sample,
# so at most len(EVENTS) distinct fillers are available for a single timeline.
EVENTS = [
    "the king issued a decree",
    "a comet appeared in the sky",
    "the bridge was completed",
    "a treaty was signed",
    "the harvest festival began",
    "a stranger arrived at the gates",
    "the library burned down",
    "a new star was discovered",
    "the river flooded the town",
    "the army marched north",
    "a peace envoy was sent",
    "the market was opened",
    "a plague swept the city",
    "the old temple was restored",
    "a fleet set sail",
    "the academy admitted its first students",
    "a rebellion broke out",
    "the queen gave birth to twins",
    "a dragon was spotted",
    "the great bell tolled",
]


def _make_timeline_chronological(n: int, target: str, target_year: int) -> str:
    """Build an ascending-year timeline with ``target`` at ``target_year``.

    Each line has the form ``Year <y>: <event>.`` and years are consecutive,
    starting at 1000. ``n - 1`` filler events are sampled from ``EVENTS``
    (padded with a generic ceremony event when the pool is too small) and the
    target is inserted on the line whose year equals ``target_year``.

    Args:
        n: Total number of timeline lines, including the target.
        target: Event text to embed in the timeline.
        target_year: Year the target should be listed under. Values outside
            the years covered by the timeline are clamped to the first or
            last line.

    Returns:
        The timeline as a single newline-joined string.
    """
    base_year = 1000
    fillers = random.sample(EVENTS, min(n - 1, len(EVENTS)))
    while len(fillers) < n - 1:
        fillers.append("the people gathered for a ceremony")
    # random.sample already returns the fillers in random order, so no extra
    # shuffle is needed. Shuffling after inserting the target would move it
    # to a random line and make target_year meaningless.
    slot = min(max(target_year - base_year, 0), len(fillers))
    fillers.insert(slot, target)
    return "\n".join(
        f"Year {base_year + i}: {event}." for i, event in enumerate(fillers)
    )


def _make_timeline_reverse(n: int, target: str, target_pos: float) -> str:
    """Build a newest-first timeline with ``target`` at a relative depth.

    ``n - 1`` filler events are sampled from ``EVENTS`` (padded with a generic
    ceremony event if the pool is too small). The target is inserted at index
    ``int(target_pos * number_of_fillers)`` and the lines are then labelled
    with descending years starting at 2000.

    Args:
        n: Total number of timeline lines, including the target.
        target: Event text to embed in the timeline.
        target_pos: Relative position of the target in the text.

    Returns:
        The timeline as a single newline-joined string.
    """
    filler_count = n - 1
    timeline_events = random.sample(EVENTS, min(filler_count, len(EVENTS)))

    # Top up with a generic event when EVENTS cannot supply enough fillers.
    shortfall = filler_count - len(timeline_events)
    if shortfall > 0:
        timeline_events.extend(["the people gathered for a ceremony"] * shortfall)

    insert_at = int(target_pos * len(timeline_events))
    timeline_events.insert(insert_at, target)

    # Count the years down from 2000, one per line.
    lines = []
    year = 2000
    for event in timeline_events:
        lines.append(f"Year {year}: {event}.")
        year -= 1
    return "\n".join(lines)


def _make_timeline_scrambled(n: int, target: str, target_pos: float) -> str:
    """Build a timeline with random years and ``target`` at a relative depth.

    ``n - 1`` filler events are sampled from ``EVENTS`` (padded with a generic
    ceremony event if the pool is too small) and shuffled. The target is then
    inserted at index ``int(target_pos * number_of_fillers)``, and every line
    is labelled with a distinct year drawn at random from 1000-1999, so the
    years carry no ordering information.

    Args:
        n: Total number of timeline lines, including the target. Must not
            exceed 1000, the number of distinct years available.
        target: Event text to embed in the timeline.
        target_pos: Relative position of the target in the text.

    Returns:
        The timeline as a single newline-joined string.
    """
    fillers = random.sample(EVENTS, min(n - 1, len(EVENTS)))
    while len(fillers) < n - 1:
        fillers.append("the people gathered for a ceremony")
    # Shuffle the fillers BEFORE inserting the target. Shuffling afterwards
    # would move the target to a random line and make target_pos a no-op.
    random.shuffle(fillers)
    idx = int(target_pos * len(fillers))
    fillers.insert(idx, target)
    years = random.sample(range(1000, 2000), len(fillers))
    return "\n".join(f"Year {y}: {e}." for y, e in zip(years, fillers))


def _find_target_year(timeline: str, target: str) -> int:
    """Return the year under which ``target`` is listed in ``timeline``.

    The timeline builders in this module emit one ``Year <y>: <event>.`` line
    per event, so the ground-truth year is read back from the generated text
    instead of being re-derived from the requested depth.

    Raises:
        ValueError: If no line of ``timeline`` carries ``target``.
    """
    match = re.search(
        r"^Year (\d+): " + re.escape(target) + r"\.$", timeline, re.MULTILINE
    )
    if match is None:
        raise ValueError("Target event not found in generated timeline")
    return int(match.group(1))


def _run_timeline_ordering(
    model_name: str,
    num_events: int,
    num_examples: int,
    out_dir: str,
    order_type: str,
    target_year: int = None,
    depths: List[float] = None,
) -> Dict[str, Any]:
    """Run one temporal-ordering condition across a sweep of target depths.

    For every depth, ``num_examples`` timelines are generated, the model is
    asked for the year of the target event, and the answer is scored. Per-depth
    predictions are written to ``<order_type>_depth_<depth>.jsonl`` and a
    summary to ``<order_type>_summary.json`` inside ``out_dir``.

    An answer counts as correct only if one of the four-digit numbers in the
    model output equals the year the target is actually listed under. Years
    of neighbouring events are (at most) one apart, so any tolerance would
    credit answers that point at a different event.

    Args:
        model_name: Model identifier forwarded to ``generate_text``.
        num_events: Number of lines per timeline, including the target.
        num_examples: Number of timelines generated per depth.
        out_dir: Directory that receives the output files.
        order_type: ``"chronological"``, ``"reverse"`` or ``"scrambled"``.
        target_year: Unused; accepted for backward compatibility.
        depths: Relative target positions to sweep. Defaults to
            ``[0.0, 0.25, 0.5, 0.75, 1.0]``.

    Returns:
        The summary dict that is also saved to disk; its ``"depths"`` entry
        maps ``str(depth)`` to the accuracy at that depth.

    Raises:
        ValueError: If ``order_type`` is not one of the supported conditions.
    """
    # Validate once up front instead of on every inner-loop iteration.
    if order_type not in ("chronological", "reverse", "scrambled"):
        raise ValueError(f"Unknown order_type: {order_type}")

    ensure_dir(out_dir)
    if depths is None:
        depths = [0.0, 0.25, 0.5, 0.75, 1.0]

    target = "a golden statue was unveiled in the central square"
    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[{order_type.upper()}] Depth {depth:.1%}")
        preds = []
        for _ in tqdm(range(num_examples), desc=f"{order_type} {depth:.1%}", leave=False):
            if order_type == "chronological":
                # Map depth onto the years the timeline covers, so that
                # depth 1.0 requests the final line rather than one past it.
                requested_year = 1000 + int(depth * (num_events - 1))
                timeline = _make_timeline_chronological(num_events, target, requested_year)
            elif order_type == "reverse":
                timeline = _make_timeline_reverse(num_events, target, depth)
            else:
                timeline = _make_timeline_scrambled(num_events, target, depth)

            # Ground truth comes from the generated text itself, so it is
            # exact for every condition (including scrambled years).
            expected = _find_target_year(timeline, target)

            prompt = (
                f"Read the following timeline of events.\n\n{timeline}\n\n"
                f"Question: In which year was a golden statue unveiled in the central square? "
                f"Answer with only the year number."
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=15,
            )
            years = re.findall(r"\b\d{4}\b", ans)
            correct = 1.0 if any(int(y) == expected for y in years) else 0.0

            preds.append({
                "model_answer": ans,
                "correct": correct,
                "expected_year": expected,
                "depth": depth,
                "order_type": order_type,
            })

        save_jsonl(os.path.join(out_dir, f"{order_type}_depth_{depth}.jsonl"), preds)
        acc = sum(p["correct"] for p in preds) / len(preds) if preds else 0.0
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[{order_type.upper()}] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": f"temporal_{order_type}",
        "num_events": num_events,
        "num_examples": num_examples,
        "order_type": order_type,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "time_minutes": (time.time() - start) / 60,
    }
    save_json(os.path.join(out_dir, f"{order_type}_summary.json"), summary)
    return summary


def run_all_temporal(
    model_name: str,
    num_events: int,
    num_examples: int,
    out_dir: str,
) -> Dict[str, Any]:
    """Execute the chronological, reverse and scrambled conditions in turn.

    Each condition writes its files to a sub-directory of ``out_dir`` named
    after the condition. The per-condition summaries are collected, saved to
    ``temporal_master_summary.json`` and returned. For every condition with at
    least three depths, a PBI value (mean accuracy of the first and last depth
    minus the accuracy at the middle depth) is logged.

    Args:
        model_name: Model identifier forwarded to each condition run.
        num_events: Number of lines per timeline.
        num_examples: Number of timelines generated per depth.
        out_dir: Root directory for all output files.

    Returns:
        Mapping from condition name to that condition's summary dict.
    """
    ensure_dir(out_dir)

    summaries = {}
    for condition in ("chronological", "reverse", "scrambled"):
        logger.info(f"\n--- Temporal Ordering: {condition.upper()} ---")
        condition_dir = os.path.join(out_dir, condition)
        summaries[condition] = _run_timeline_ordering(
            model_name, num_events, num_examples, condition_dir, condition,
        )

    master_path = os.path.join(out_dir, "temporal_master_summary.json")
    save_json(master_path, summaries)

    # Compare edge accuracy against mid-context accuracy per condition.
    logger.info("\n--- Temporal Ordering Comparison ---")
    for condition, summary in summaries.items():
        accuracies = list(summary["depths"].values())
        if len(accuracies) < 3:
            continue
        edge_mean = (accuracies[0] + accuracies[-1]) / 2
        middle = accuracies[len(accuracies) // 2]
        pbi = edge_mean - middle
        logger.info(f"  {condition:15s} PBI={pbi:+.3f}")

    return summaries