"""
Temporal Order Experiment

Tests position bias when information has an inherent chronological ordering.
Three conditions are compared: chronological, reverse-chronological, and scrambled.
"""
import logging
import os
import random
import re
import time
from typing import List, Dict, Any
from tqdm import tqdm
from src.generator import generate_text
from src.utils import ensure_dir, save_jsonl, save_json
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Pool of filler ("distractor") events. The timeline builders below sample
# from this list without replacement to surround the target event.
EVENTS: List[str] = [
    "the king issued a decree",
    "a comet appeared in the sky",
    "the bridge was completed",
    "a treaty was signed",
    "the harvest festival began",
    "a stranger arrived at the gates",
    "the library burned down",
    "a new star was discovered",
    "the river flooded the town",
    "the army marched north",
    "a peace envoy was sent",
    "the market was opened",
    "a plague swept the city",
    "the old temple was restored",
    "a fleet set sail",
    "the academy admitted its first students",
    "a rebellion broke out",
    "the queen gave birth to twins",
    "a dragon was spotted",
    "the great bell tolled",
]
def _make_timeline_chronological(n: int, target: str, target_year: int) -> str:
    """Build an ascending-year timeline with ``target`` placed at ``target_year``.

    The timeline has ``n`` lines of the form ``Year <y>: <event>.``, where the
    i-th line (0-based) carries year ``1000 + i``.

    Args:
        n: Total number of events in the timeline, including the target.
        target: Text of the event the model will be asked about.
        target_year: Year the target should be assigned. It is clamped to the
            range the timeline actually covers (``1000`` .. ``1000 + n - 1``).

    Returns:
        The timeline as a newline-joined string.
    """
    events = random.sample(EVENTS, min(n - 1, len(EVENTS)))
    # EVENTS is finite; pad with a generic event when more fillers are needed.
    while len(events) < n - 1:
        events.append("the people gathered for a ceremony")
    # Mix the padding in with the sampled fillers *before* placing the target,
    # so the shuffle cannot move the target away from its requested slot.
    random.shuffle(events)
    # Line i gets year 1000 + i, so the slot index is the offset from 1000.
    # Clamp so an out-of-range year still yields a valid position.
    slot = min(max(target_year - 1000, 0), len(events))
    events.insert(slot, target)
    return "\n".join(f"Year {1000 + i}: {e}." for i, e in enumerate(events))
def _make_timeline_reverse(n: int, target: str, target_pos: float) -> str:
    """Build a descending-year timeline with ``target`` at a relative position.

    Filler events are sampled from ``EVENTS`` (padded with a generic event if
    more than ``len(EVENTS)`` are needed). The target is inserted at index
    ``int(target_pos * number_of_fillers)``, and the lines are numbered
    downwards starting from year 2000.

    Args:
        n: Total number of events in the timeline, including the target.
        target: Text of the event the model will be asked about.
        target_pos: Relative position of the target among the fillers.

    Returns:
        The timeline as a newline-joined string.
    """
    filler_count = n - 1
    lines = random.sample(EVENTS, min(filler_count, len(EVENTS)))
    shortfall = filler_count - len(lines)
    if shortfall > 0:
        lines.extend(["the people gathered for a ceremony"] * shortfall)

    slot = int(target_pos * len(lines))
    lines.insert(slot, target)

    # Reverse chronological: the first line is year 2000, each later line is
    # one year earlier.
    rendered = []
    for offset, event in enumerate(lines):
        rendered.append(f"Year {2000 - offset}: {event}.")
    return "\n".join(rendered)
def _make_timeline_scrambled(n: int, target: str, target_pos: float) -> str:
    """Build a timeline whose years are in random order.

    The target's position *in the text* is controlled by ``target_pos``; only
    the year labels are scrambled. Years are distinct random draws from
    1000-1999.

    Args:
        n: Total number of events in the timeline, including the target.
        target: Text of the event the model will be asked about.
        target_pos: Relative position of the target among the fillers; the
            target is inserted at index ``int(target_pos * number_of_fillers)``.

    Returns:
        The timeline as a newline-joined string.

    Raises:
        ValueError: If ``n`` exceeds 1000, because distinct years are drawn
            from a pool of 1000 values.
    """
    events = random.sample(EVENTS, min(n - 1, len(EVENTS)))
    # EVENTS is finite; pad with a generic event when more fillers are needed.
    while len(events) < n - 1:
        events.append("the people gathered for a ceremony")
    # Shuffle the fillers first. Shuffling after the insert (as before) threw
    # away the requested position and made the depth sweep meaningless.
    random.shuffle(events)
    idx = int(target_pos * len(events))
    events.insert(idx, target)
    years = random.sample(range(1000, 2000), len(events))
    return "\n".join(f"Year {y}: {e}." for y, e in zip(years, events))
def _run_timeline_ordering(
    model_name: str,
    num_events: int,
    num_examples: int,
    out_dir: str,
    order_type: str,
    target_year: int = None,
    depths: List[float] = None,
) -> Dict[str, Any]:
    """Run one temporal-ordering condition across a sweep of target depths.

    For every depth, ``num_examples`` timelines are generated, the model is
    asked for the year of the target event, and the answer is scored. The
    per-depth predictions are written to ``<order_type>_depth_<depth>.jsonl``
    and a summary to ``<order_type>_summary.json`` inside ``out_dir``.

    The expected year is read back from the generated timeline rather than
    recomputed from ``depth``, so it is always the year actually printed next
    to the target (this also gives the scrambled condition a real ground
    truth). An answer is correct only if it mentions exactly that year; a
    tolerance window would accept the years of neighbouring events.

    Args:
        model_name: Model identifier forwarded to ``generate_text``.
        num_events: Number of events per timeline, including the target.
        num_examples: Number of timelines generated per depth.
        out_dir: Directory that receives the prediction and summary files.
        order_type: ``"chronological"``, ``"reverse"`` or ``"scrambled"``.
        target_year: Unused; kept so existing callers keep working.
        depths: Relative target positions to sweep. Defaults to
            ``[0.0, 0.25, 0.5, 0.75, 1.0]``.

    Returns:
        The summary dict that is also saved to ``<order_type>_summary.json``.

    Raises:
        ValueError: If ``order_type`` is not one of the supported values
            (raised when the first timeline is built).
    """
    ensure_dir(out_dir)
    if depths is None:
        depths = [0.0, 0.25, 0.5, 0.75, 1.0]

    target = "a golden statue was unveiled in the central square"
    # Matches the target's own "Year <y>: <target>." line in a timeline.
    target_line = re.compile(
        rf"^Year (\d+): {re.escape(target)}\.$", re.MULTILINE
    )
    year_pattern = re.compile(r"\b\d{4}\b")

    results = {}
    start = time.time()
    for depth in depths:
        logger.info(f"[{order_type.upper()}] Depth {depth:.1%}")
        preds = []
        for _ in tqdm(range(num_examples), desc=f"{order_type} {depth:.1%}", leave=False):
            if order_type == "chronological":
                timeline = _make_timeline_chronological(
                    num_events, target, 1000 + int(depth * num_events)
                )
            elif order_type == "reverse":
                timeline = _make_timeline_reverse(num_events, target, depth)
            elif order_type == "scrambled":
                timeline = _make_timeline_scrambled(num_events, target, depth)
            else:
                raise ValueError(f"Unknown order_type: {order_type}")

            # Ground truth is whatever year the timeline really shows for the
            # target; None only if the target line is unexpectedly missing.
            found = target_line.search(timeline)
            expected = int(found.group(1)) if found else None

            prompt = (
                f"Read the following timeline of events.\n\n{timeline}\n\n"
                "Question: In which year was a golden statue unveiled in the central square? "
                "Answer with only the year number."
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=15,
            )

            answered_years = {int(y) for y in year_pattern.findall(ans)}
            # A missing ground truth (None) is never in the set, so it scores 0.
            correct = 1.0 if expected in answered_years else 0.0
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "expected_year": expected,
                "depth": depth,
                "order_type": order_type,
            })

        save_jsonl(os.path.join(out_dir, f"{order_type}_depth_{depth}.jsonl"), preds)
        acc = sum(p["correct"] for p in preds) / len(preds) if preds else 0.0
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[{order_type.upper()}] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": f"temporal_{order_type}",
        "num_events": num_events,
        "num_examples": num_examples,
        "order_type": order_type,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "time_minutes": (time.time() - start) / 60,
    }
    save_json(os.path.join(out_dir, f"{order_type}_summary.json"), summary)
    return summary
def run_all_temporal(
    model_name: str,
    num_events: int,
    num_examples: int,
    out_dir: str,
) -> Dict[str, Any]:
    """Run the chronological, reverse and scrambled conditions in turn.

    Each condition writes its files to its own sub-directory of ``out_dir``.
    The combined summaries are saved to ``temporal_master_summary.json``, and
    a per-condition "PBI" value is logged: the mean of the first and last
    depth accuracies minus the accuracy at the middle depth.

    Args:
        model_name: Model identifier forwarded to each condition run.
        num_events: Number of events per timeline.
        num_examples: Number of timelines generated per depth.
        out_dir: Root output directory.

    Returns:
        Mapping from condition name to that condition's summary dict.
    """
    ensure_dir(out_dir)

    conditions = ("chronological", "reverse", "scrambled")
    all_results = {}
    for order in conditions:
        logger.info(f"\n--- Temporal Ordering: {order.upper()} ---")
        condition_dir = os.path.join(out_dir, order)
        all_results[order] = _run_timeline_ordering(
            model_name, num_events, num_examples, condition_dir, order
        )

    master_path = os.path.join(out_dir, "temporal_master_summary.json")
    save_json(master_path, all_results)

    # Compare the conditions: edge accuracy versus middle accuracy.
    logger.info("\n--- Temporal Ordering Comparison ---")
    for order, res in all_results.items():
        accs = list(res["depths"].values())
        if len(accs) < 3:
            # Too few depths to have distinct edges and a middle.
            continue
        edge_mean = (accs[0] + accs[-1]) / 2
        middle = accs[len(accs) // 2]
        pbi = edge_mean - middle
        logger.info(f" {order:15s} PBI={pbi:+.3f}")
    return all_results
|