| """ |
| Temporal Order Experiment |
| Tests position bias when information has inherent chronological ordering: |
| Chronological, Reverse-chronological, Scrambled. |
| """ |
import logging
import os
import random
import re
import time
from typing import Any, Dict, List, Optional

from tqdm import tqdm

from src.generator import generate_text
from src.utils import ensure_dir, save_jsonl, save_json
|
|
logger = logging.getLogger(__name__)

# Pool of distractor ("filler") events. The timeline builders draw from this
# list without replacement via random.sample, so entries must stay unique.
EVENTS = [
    "the king issued a decree",
    "a comet appeared in the sky",
    "the bridge was completed",
    "a treaty was signed",
    "the harvest festival began",
    "a stranger arrived at the gates",
    "the library burned down",
    "a new star was discovered",
    "the river flooded the town",
    "the army marched north",
    "a peace envoy was sent",
    "the market was opened",
    "a plague swept the city",
    "the old temple was restored",
    "a fleet set sail",
    "the academy admitted its first students",
    "a rebellion broke out",
    "the queen gave birth to twins",
    "a dragon was spotted",
    "the great bell tolled",
]
|
|
|
|
def _make_timeline_chronological(n: int, target: str, target_year: int) -> str:
    """Build an ascending-year timeline with ``target`` labelled ``target_year``.

    Filler events are sampled without replacement from ``EVENTS``; when more
    fillers are needed than the pool holds, a generic ceremony event pads the
    list. Line ``i`` of the output is labelled ``Year 1000+i``, so the target
    is inserted at index ``target_year - 1000``.

    Args:
        n: Total number of timeline lines, including the target. Values
            below 1 produce a timeline containing only the target.
        target: Text of the event to embed.
        target_year: Year label the target should carry. Clamped to the
            years the timeline actually spans (``1000`` .. ``1000 + n - 1``).

    Returns:
        Newline-joined lines of the form ``"Year <year>: <event>."``.
    """
    num_fillers = max(n - 1, 0)
    events = random.sample(EVENTS, min(num_fillers, len(EVENTS)))
    events.extend(
        ["the people gathered for a ceremony"] * (num_fillers - len(events))
    )
    # Shuffle only the fillers so the padding events are mixed in; the target
    # is inserted afterwards so its position (and year) stays deterministic.
    random.shuffle(events)

    idx = min(max(target_year - 1000, 0), len(events))
    events.insert(idx, target)

    return "\n".join(f"Year {1000 + i}: {e}." for i, e in enumerate(events))
|
|
|
|
def _make_timeline_reverse(n: int, target: str, target_pos: float) -> str:
    """Build a descending-year timeline with ``target`` at a relative depth.

    Filler events are sampled without replacement from ``EVENTS`` and padded
    with a generic ceremony event when the pool is exhausted. Line ``i`` of
    the output is labelled ``Year 2000-i``.

    Args:
        n: Total number of timeline lines, including the target. Values
            below 1 produce a timeline containing only the target.
        target: Text of the event to embed.
        target_pos: Relative position of the target among the fillers;
            ``0.0`` puts it first and ``1.0`` last. Out-of-range values are
            clamped.

    Returns:
        Newline-joined lines of the form ``"Year <year>: <event>."``.
    """
    num_fillers = max(n - 1, 0)
    events = random.sample(EVENTS, min(num_fillers, len(EVENTS)))
    events.extend(
        ["the people gathered for a ceremony"] * (num_fillers - len(events))
    )

    # Clamp so a negative position cannot silently index from the end.
    idx = min(max(int(target_pos * len(events)), 0), len(events))
    events.insert(idx, target)

    return "\n".join(f"Year {2000 - i}: {e}." for i, e in enumerate(events))
|
|
|
|
def _make_timeline_scrambled(n: int, target: str, target_pos: float) -> str:
    """Build a timeline whose year labels are in random order.

    Filler events are sampled without replacement from ``EVENTS`` and padded
    with a generic ceremony event when the pool is exhausted. Each line gets
    a distinct year drawn at random from ``1000..1999``, so the labels carry
    no ordering. The target is placed at the requested relative depth in the
    text.

    Args:
        n: Total number of timeline lines, including the target. Values
            below 1 produce a timeline containing only the target.
        target: Text of the event to embed.
        target_pos: Relative position of the target among the fillers;
            ``0.0`` puts it first and ``1.0`` last. Out-of-range values are
            clamped.

    Returns:
        Newline-joined lines of the form ``"Year <year>: <event>."``.

    Raises:
        ValueError: If more than 1000 lines are requested, because distinct
            years are sampled from a pool of 1000 values.
    """
    num_fillers = max(n - 1, 0)
    events = random.sample(EVENTS, min(num_fillers, len(EVENTS)))
    events.extend(
        ["the people gathered for a ceremony"] * (num_fillers - len(events))
    )
    # Shuffle the fillers *before* inserting the target: shuffling afterwards
    # would move the target to a random line and discard ``target_pos``.
    random.shuffle(events)

    idx = min(max(int(target_pos * len(events)), 0), len(events))
    events.insert(idx, target)

    years = random.sample(range(1000, 2000), len(events))
    return "\n".join(f"Year {y}: {e}." for y, e in zip(years, events))
|
|
|
|
# Event the model is asked to locate; it must match the wording of the
# question in the prompt below.
_TARGET_EVENT = "a golden statue was unveiled in the central square"

_ORDER_TYPES = ("chronological", "reverse", "scrambled")


def _find_target_year(timeline: str, target: str) -> Optional[int]:
    """Return the year label of the timeline line carrying ``target``, if any."""
    suffix = f": {target}."
    for line in timeline.splitlines():
        if line.endswith(suffix):
            match = re.match(r"Year (\d+)", line)
            if match:
                return int(match.group(1))
    return None


def _score_answer(answer: str, expected: Optional[int]) -> float:
    """Return 1.0 if ``answer`` contains ``expected`` as a 4-digit year, else 0.0."""
    if expected is None:
        return 0.0
    years = re.findall(r"\b\d{4}\b", answer)
    return 1.0 if any(int(y) == expected for y in years) else 0.0


def _run_timeline_ordering(
    model_name: str,
    num_events: int,
    num_examples: int,
    out_dir: str,
    order_type: str,
    target_year: Optional[int] = None,
    depths: Optional[List[float]] = None,
) -> Dict[str, Any]:
    """Run one temporal-ordering condition across a sweep of target depths.

    For every depth, ``num_examples`` timelines are generated, the model is
    asked for the year of the target event, and the reply is scored against
    the year the target actually carries in that timeline. The ground truth
    is read back from the generated text, so scoring is exact for all three
    conditions, including ``"scrambled"`` where years are random.

    Per-depth predictions are written to
    ``<out_dir>/<order_type>_depth_<depth>.jsonl`` and a summary to
    ``<out_dir>/<order_type>_summary.json``.

    Args:
        model_name: Forwarded to ``generate_text``.
        num_events: Number of lines per timeline, including the target.
        num_examples: Number of timelines generated per depth.
        out_dir: Output directory; created via ``ensure_dir``.
        order_type: ``"chronological"``, ``"reverse"`` or ``"scrambled"``.
        target_year: Currently unused; kept for interface compatibility.
        depths: Relative target positions in ``[0, 1]``. Defaults to
            ``[0.0, 0.25, 0.5, 0.75, 1.0]``.

    Returns:
        The summary dict that is also written to disk; ``"depths"`` maps
        ``str(depth)`` to the accuracy at that depth.

    Raises:
        ValueError: If ``order_type`` is not one of the supported values.
    """
    # Fail fast, before any output is created or any model call is made.
    if order_type not in _ORDER_TYPES:
        raise ValueError(f"Unknown order_type: {order_type}")

    ensure_dir(out_dir)
    if depths is None:
        depths = [0.0, 0.25, 0.5, 0.75, 1.0]

    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[{order_type.upper()}] Depth {depth:.1%}")
        preds = []
        for _ in tqdm(
            range(num_examples), desc=f"{order_type} {depth:.1%}", leave=False
        ):
            if order_type == "chronological":
                timeline = _make_timeline_chronological(
                    num_events, _TARGET_EVENT, 1000 + int(depth * num_events)
                )
            elif order_type == "reverse":
                timeline = _make_timeline_reverse(num_events, _TARGET_EVENT, depth)
            else:
                timeline = _make_timeline_scrambled(num_events, _TARGET_EVENT, depth)

            # Ground truth comes from the generated text, not from a formula
            # that may disagree with where the builder put the target.
            expected = _find_target_year(timeline, _TARGET_EVENT)
            if expected is None:
                logger.warning(
                    "[%s] target event not found in generated timeline",
                    order_type.upper(),
                )

            prompt = (
                f"Read the following timeline of events.\n\n{timeline}\n\n"
                "Question: In which year was a golden statue unveiled in the central square? "
                "Answer with only the year number."
            )
            # NOTE(review): assumes generate_text returns a str — confirm.
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=15,
            )

            preds.append({
                "model_answer": ans,
                "correct": _score_answer(ans, expected),
                "expected_year": expected,
                "depth": depth,
                "order_type": order_type,
            })

        save_jsonl(os.path.join(out_dir, f"{order_type}_depth_{depth}.jsonl"), preds)
        acc = sum(p["correct"] for p in preds) / len(preds) if preds else 0.0
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[{order_type.upper()}] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": f"temporal_{order_type}",
        "num_events": num_events,
        "num_examples": num_examples,
        "order_type": order_type,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "time_minutes": (time.time() - start) / 60,
    }
    save_json(os.path.join(out_dir, f"{order_type}_summary.json"), summary)
    return summary
|
|
|
|
def run_all_temporal(
    model_name: str,
    num_events: int,
    num_examples: int,
    out_dir: str,
) -> Dict[str, Any]:
    """Run all three temporal ordering conditions.

    Each condition (chronological, reverse, scrambled) is run through
    ``_run_timeline_ordering`` into its own sub-directory of ``out_dir``.
    The per-condition summaries are written together to
    ``<out_dir>/temporal_master_summary.json``, and a position-bias index
    (mean accuracy at the two extreme depths minus accuracy at the middle
    depth) is logged for each condition.

    Args:
        model_name: Forwarded to ``_run_timeline_ordering``.
        num_events: Number of lines per timeline.
        num_examples: Number of timelines generated per depth.
        out_dir: Root output directory; created via ``ensure_dir``.

    Returns:
        Mapping from condition name to that condition's summary dict.
    """
    ensure_dir(out_dir)
    all_results = {}

    for order in ("chronological", "reverse", "scrambled"):
        logger.info(f"\n--- Temporal Ordering: {order.upper()} ---")
        all_results[order] = _run_timeline_ordering(
            model_name, num_events, num_examples,
            os.path.join(out_dir, order), order,
        )

    save_json(os.path.join(out_dir, "temporal_master_summary.json"), all_results)

    logger.info("\n--- Temporal Ordering Comparison ---")
    for order, res in all_results.items():
        # Sort by numeric depth so "first", "middle" and "last" really are
        # the shallowest, central and deepest positions, independent of the
        # order in which the summary dict was built.
        by_depth = sorted(res["depths"].items(), key=lambda kv: float(kv[0]))
        accs = [acc for _, acc in by_depth]
        if len(accs) >= 3:
            pbi = (accs[0] + accs[-1]) / 2 - accs[len(accs) // 2]
            logger.info(f"  {order:15s} PBI={pbi:+.3f}")

    return all_results
|
|