File size: 4,270 Bytes
7772fe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
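Recency × Position Interaction
Tests whether recency bias (a preference for more recent events) interacts with position bias.
A fixed target event is inserted at different depths of a timeline and dated with a year
that grows with its depth, among distractor events that carry randomly assigned years.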
"""
import logging
import os
import random
import re
import time
from typing import Any, Dict, List, Optional, Tuple

from tqdm import tqdm

from src.generator import generate_text
from src.utils import ensure_dir, save_jsonl, save_json

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Pool of distractor events. Each generated timeline draws a random subset of
# these (without replacement) and prefixes every entry with a "Year N:" label.
# None of them mentions the golden statue, which is reserved for the target
# event that the model is asked about.
EVENTS = [
    "the king issued a decree",
    "a comet appeared in the sky",
    "the bridge was completed",
    "a treaty was signed",
    "the harvest festival began",
    "a stranger arrived at the gates",
    "the library burned down",
    "a new star was discovered",
    "the river flooded the town",
    "the army marched north",
    "a peace envoy was sent",
    "the market was opened",
    "a plague swept the city",
    "the old temple was restored",
    "a fleet set sail",
    "the academy admitted its first students",
    "a rebellion broke out",
    "the queen gave birth to twins",
    "a dragon was spotted",
    "the great bell tolled",
]


# Event whose year the model is asked to recall.
_TARGET_EVENT = "a golden statue was unveiled"
# Line used to pad the timeline when more events are requested than EVENTS holds.
_FILLER_EVENT = "the people gathered for a ceremony"
# An answer is accepted when it is strictly closer than this many years to the target.
_YEAR_TOLERANCE = 5
# Distractor years are drawn from this half-open range.
_YEAR_MIN, _YEAR_MAX = 1000, 2000
# Stand-alone four-digit numbers in the model's answer are treated as years.
_YEAR_PATTERN = re.compile(r"\b\d{4}\b")


def _build_timeline(num_events: int, depth: float) -> Tuple[str, int]:
    """Build one timeline with the target event at ``depth``; return ``(text, target_year)``."""
    distractors = random.sample(EVENTS, min(num_events, len(EVENTS)))
    # Pad with a generic filler when EVENTS cannot supply enough distinct events.
    distractors.extend([_FILLER_EVENT] * (num_events - len(distractors)))

    # Early position = old year, late position = new year (1000 -> 2000).
    target_year = _YEAR_MIN + int(depth * 1000)
    position = int(depth * len(distractors))

    # Distractor years must stay outside the scoring window around the target
    # year; otherwise the timeline is ambiguous and an answer naming a
    # distractor's year would be scored as correct.
    year_pool = [
        year
        for year in range(_YEAR_MIN, _YEAR_MAX)
        if abs(year - target_year) >= _YEAR_TOLERANCE
    ]
    if len(distractors) > len(year_pool):
        raise ValueError(
            f"num_events={num_events} exceeds the {len(year_pool)} distinct "
            f"distractor years available for target year {target_year}"
        )
    distractor_years = random.sample(year_pool, len(distractors))

    lines = [
        f"Year {year}: {event}."
        for year, event in zip(distractor_years, distractors)
    ]
    lines.insert(position, f"Year {target_year}: {_TARGET_EVENT}.")
    return "\n".join(lines), target_year


def _score_answer(answer: str, target_year: int) -> float:
    """Return 1.0 if any four-digit number in ``answer`` is within tolerance of ``target_year``."""
    hit = any(
        abs(int(year) - target_year) < _YEAR_TOLERANCE
        for year in _YEAR_PATTERN.findall(answer)
    )
    return 1.0 if hit else 0.0


def run_recency_interaction(
    model_name: str,
    num_events: int,
    num_examples: int,
    out_dir: str,
    depths: Optional[List[float]] = None,
) -> Dict[str, Any]:
    """Test recency bias: older event at early position, newer event at late position.

    For every depth, ``num_examples`` timelines are generated. Each timeline
    holds ``num_events`` distractor events with random years plus one target
    event inserted at the relative position ``depth`` and dated
    ``1000 + int(depth * 1000)``. The model is asked for the target's year and
    its answer is scored against that year.

    Per-depth predictions are written to ``recency_depth_<depth>.jsonl`` and the
    aggregate to ``recency_summary.json``, both inside ``out_dir``.

    Args:
        model_name: Passed through to ``generate_text`` as ``model_name``.
        num_events: Number of distractor events per timeline.
        num_examples: Number of timelines generated per depth.
        out_dir: Directory that receives the output files.
        depths: Relative insertion positions of the target event; defaults to
            ``[0.0, 0.25, 0.5, 0.75, 1.0]``. The year mapping is designed for
            values in ``[0, 1]``.

    Returns:
        The summary dict: experiment name, ``num_events``, ``num_examples``,
        accuracy per depth (keyed by ``str(depth)``) and elapsed minutes.

    Raises:
        ValueError: If ``num_events`` is negative or larger than the number of
            distinct distractor years available.
    """
    ensure_dir(out_dir)
    if depths is None:
        depths = [0.0, 0.25, 0.5, 0.75, 1.0]

    accuracy_by_depth: Dict[float, float] = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[RECENCY] Depth {depth:.1%}")
        preds = []
        for _ in tqdm(range(num_examples), desc=f"Recency {depth:.1%}", leave=False):
            timeline, target_year = _build_timeline(num_events, depth)

            prompt = (
                f"Read the following timeline.\n\n{timeline}\n\n"
                f"Question: In which year was a golden statue unveiled? "
                f"Answer with only the year number."
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=15,
            )

            preds.append({
                "model_answer": ans,
                "correct": _score_answer(ans, target_year),
                "expected_year": target_year,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"recency_depth_{depth}.jsonl"), preds)
        acc = sum(p["correct"] for p in preds) / len(preds) if preds else 0.0
        accuracy_by_depth[depth] = acc
        logger.info(f"[RECENCY] Depth {depth:.1%}: acc={acc:.3f}")

    # Measure once so the saved summary and the log line report the same value.
    elapsed_minutes = (time.time() - start) / 60
    summary = {
        "experiment": "recency_interaction",
        "num_events": num_events,
        "num_examples": num_examples,
        "depths": {str(d): accuracy_by_depth[d] for d in depths},
        "time_minutes": elapsed_minutes,
    }
    save_json(os.path.join(out_dir, "recency_summary.json"), summary)
    logger.info(f"[RECENCY] Time={elapsed_minutes:.1f} min")
    return summary