File size: 4,270 Bytes
7772fe6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 | """
Recency × Position Interaction
Tests whether recency bias (preferring recent events) interacts with position bias.
The same target event is placed at different positions in a timeline, with a year (timestamp) that increases with its position.
"""
import logging
import os
import random
import re
import time
from typing import List, Dict, Any, Optional, Tuple

from tqdm import tqdm

from src.generator import generate_text
from src.utils import ensure_dir, save_jsonl, save_json
logger = logging.getLogger(__name__)
# Pool of distractor events for the recency experiment. Each example draws
# min(num_events, len(EVENTS)) of them without replacement via random.sample,
# so the list order affects which events are picked for a given RNG state.
# Entries must not contain the target phrase "a golden statue was unveiled",
# otherwise a distractor would be indistinguishable from the event the model
# is asked about.
EVENTS = [
    "the king issued a decree",
    "a comet appeared in the sky",
    "the bridge was completed",
    "a treaty was signed",
    "the harvest festival began",
    "a stranger arrived at the gates",
    "the library burned down",
    "a new star was discovered",
    "the river flooded the town",
    "the army marched north",
    "a peace envoy was sent",
    "the market was opened",
    "a plague swept the city",
    "the old temple was restored",
    "a fleet set sail",
    "the academy admitted its first students",
    "a rebellion broke out",
    "the queen gave birth to twins",
    "a dragon was spotted",
    "the great bell tolled",
]
def _build_timeline(
    num_events: int,
    depth: float,
    rng: Any,
    max_year_error: int,
) -> Tuple[str, int]:
    """Build one timeline and return ``(timeline_text, target_year)``.

    ``rng`` is either the ``random`` module or a ``random.Random`` instance;
    only its ``sample`` method is used.
    """
    target = "a golden statue was unveiled"
    distractors = rng.sample(EVENTS, min(num_events, len(EVENTS)))
    # Pad with a repeated filler event when the pool is smaller than requested.
    distractors += ["the people gathered for a ceremony"] * (
        num_events - len(distractors)
    )

    # The target year grows with position: depth 0.0 -> 1000, depth 1.0 -> 2000.
    target_year = 1000 + int(depth * 1000)

    # Distractor years are distinct and kept more than max_year_error away
    # from the target, so the question has a single answer and an echoed
    # distractor year can never be scored as correct.
    candidate_years = [
        y for y in range(1000, 2000) if abs(y - target_year) > max_year_error
    ]
    if len(distractors) > len(candidate_years):
        raise ValueError(
            f"num_events={num_events} exceeds the {len(candidate_years)} "
            f"distractor years available for target year {target_year}"
        )
    years = rng.sample(candidate_years, len(distractors))
    lines = [f"Year {y}: {e}." for y, e in zip(years, distractors)]

    # Insert the target at relative position `depth` among the distractors
    # (depth 0.0 -> first line, depth 1.0 -> last line).
    idx = int(depth * len(distractors))
    lines.insert(idx, f"Year {target_year}: {target}.")
    return "\n".join(lines), target_year


def _score_answer(answer: str, target_year: int, max_year_error: int) -> float:
    """Return 1.0 if any 4-digit number in ``answer`` is within
    ``max_year_error`` of ``target_year``, else 0.0."""
    years = re.findall(r"\b\d{4}\b", answer)
    hit = any(abs(int(y) - target_year) <= max_year_error for y in years)
    return 1.0 if hit else 0.0


def run_recency_interaction(
    model_name: str,
    num_events: int,
    num_examples: int,
    out_dir: str,
    depths: Optional[List[float]] = None,
    *,
    seed: Optional[int] = None,
    max_year_error: int = 4,
) -> Dict[str, Any]:
    """Test recency bias: older event at early position, newer event at late position.

    For each depth, ``num_examples`` timelines are built. Each timeline holds
    ``num_events`` distractor events (sampled from ``EVENTS``, padded with a
    filler event when the pool is too small). It also holds one target event,
    "a golden statue was unveiled", inserted at relative position ``depth``.
    The target's year is ``1000 + int(depth * 1000)``, so an early position
    carries an old year and a late position a recent one. Distractors get
    distinct random years from [1000, 2000) that lie more than
    ``max_year_error`` years from the target year. The model is asked for the
    target's year. An answer is correct when any 4-digit number in it is
    within ``max_year_error`` of the target year.

    Args:
        model_name: Model identifier forwarded to ``generate_text``.
        num_events: Number of distractor events per timeline.
        num_examples: Number of timelines generated per depth.
        out_dir: Directory receiving ``recency_depth_{depth}.jsonl`` for each
            depth plus ``recency_summary.json``.
        depths: Relative target positions in [0, 1]; defaults to
            ``[0.0, 0.25, 0.5, 0.75, 1.0]``.
        seed: Optional seed for a private RNG so timelines are reproducible.
            ``None`` uses the global ``random`` module state.
        max_year_error: Largest absolute difference between a year in the
            answer and the target year that still counts as correct.

    Returns:
        The summary dict also written to ``recency_summary.json``. It holds
        the experiment name, ``num_events``, ``num_examples``, per-depth
        accuracy keyed by ``str(depth)``, and wall-clock ``time_minutes``.

    Raises:
        ValueError: If a depth lies outside [0, 1] or ``max_year_error`` is
            negative. Also raised if ``num_events`` exceeds the number of
            available distractor years (or is negative).
    """
    if depths is None:
        depths = [0.0, 0.25, 0.5, 0.75, 1.0]
    out_of_range = [d for d in depths if not 0.0 <= d <= 1.0]
    if out_of_range:
        raise ValueError(f"depths must lie in [0, 1], got {out_of_range}")
    if max_year_error < 0:
        raise ValueError(f"max_year_error must be >= 0, got {max_year_error}")
    ensure_dir(out_dir)

    # Without a seed, keep using the global RNG (the module exposes `sample`).
    rng: Any = random if seed is None else random.Random(seed)

    results: Dict[float, Dict[str, Any]] = {}
    start = time.time()
    for depth in depths:
        logger.info(f"[RECENCY] Depth {depth:.1%}")
        preds = []
        for _ in tqdm(range(num_examples), desc=f"Recency {depth:.1%}", leave=False):
            timeline, target_year = _build_timeline(
                num_events, depth, rng, max_year_error
            )
            prompt = (
                f"Read the following timeline.\n\n{timeline}\n\n"
                f"Question: In which year was a golden statue unveiled? "
                f"Answer with only the year number."
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=15,
            )
            preds.append({
                "model_answer": ans,
                "correct": _score_answer(ans, target_year, max_year_error),
                "expected_year": target_year,
                "depth": depth,
            })
        save_jsonl(os.path.join(out_dir, f"recency_depth_{depth}.jsonl"), preds)
        acc = sum(p["correct"] for p in preds) / len(preds) if preds else 0.0
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[RECENCY] Depth {depth:.1%}: acc={acc:.3f}")

    # Measure once so the saved summary and the log line agree.
    elapsed_min = (time.time() - start) / 60
    summary = {
        "experiment": "recency_interaction",
        "num_events": num_events,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "time_minutes": elapsed_min,
    }
    save_json(os.path.join(out_dir, "recency_summary.json"), summary)
    logger.info(f"[RECENCY] Time={elapsed_min:.1f} min")
    return summary
|