File size: 1,997 Bytes
7772fe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python3
"""Temporal Position Bias — Master Runner"""
import argparse
import logging
import os
import sys

from experiments.chronological_vs_reverse import run_all_temporal
from experiments.recency_interaction import run_recency_interaction
from src.utils import save_json

# Root logging: INFO and above, timestamped "time - LEVEL - message" lines,
# written to stdout rather than the default stderr.
logging.basicConfig(
    level=logging.INFO,
    stream=sys.stdout,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
# Module-scoped logger used by main() for progress and summary output.
logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Parse command-line options for the temporal position-bias runner.

    Args:
        argv: Argument strings to parse, excluding the program name.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            which is exactly the behaviour of the old zero-argument form.

    Returns:
        argparse.Namespace with attributes ``model``, ``output``,
        ``num_events`` and ``num_examples``.
    """
    p = argparse.ArgumentParser(description="Temporal Position Bias")
    p.add_argument(
        "--model",
        default="Qwen/Qwen2.5-1.5B-Instruct",
        help="model name forwarded to every experiment (default: %(default)s)",
    )
    p.add_argument(
        "--output",
        default="./results",
        help="root directory for all result files (default: %(default)s)",
    )
    p.add_argument(
        "--num-events",
        type=int,
        default=100,
        help="num_events value forwarded to every experiment (default: %(default)s)",
    )
    p.add_argument(
        "--num-examples",
        type=int,
        default=30,
        help="num_examples value forwarded to every experiment (default: %(default)s)",
    )
    # Passing argv through (None -> sys.argv[1:]) keeps CLI use unchanged
    # while allowing programmatic/test invocation.
    return p.parse_args(argv)


def _compute_pbi(accs):
    """Return the PBI score for an ordered sequence of accuracies, or None.

    PBI = mean(first, last) - middle, where "middle" is
    ``accs[len(accs) // 2]`` (for an even count this is the upper of the two
    central entries). A positive value means the two ends score higher than
    the middle. Returns ``None`` for fewer than three values, because the
    ends and the middle would then coincide.
    """
    if len(accs) < 3:
        return None
    middle = accs[len(accs) // 2]
    return (accs[0] + accs[-1]) / 2 - middle


def main():
    """Run both experiments, save a combined summary, and log PBI per ordering.

    Each experiment is given its own sub-directory path under ``--output``
    (``exp1_ordering`` and ``exp2_recency``). Both result objects are
    combined into one dict and passed to ``save_json`` for
    ``<output>/master_summary.json``.
    """
    args = parse_args()
    model = args.model
    out_root = args.output
    os.makedirs(out_root, exist_ok=True)

    logger.info("\n--- Experiment 1: Temporal Ordering (Chronological vs Reverse vs Scrambled) ---")
    temporal_results = run_all_temporal(
        model, args.num_events, args.num_examples,
        os.path.join(out_root, "exp1_ordering"),
    )

    logger.info("\n--- Experiment 2: Recency × Position Interaction ---")
    recency_results = run_recency_interaction(
        model, args.num_events, args.num_examples,
        os.path.join(out_root, "exp2_recency"),
    )

    master = {
        "temporal_ordering": temporal_results,
        "recency_interaction": recency_results,
    }
    save_json(os.path.join(out_root, "master_summary.json"), master)

    logger.info("\n--- Key Findings ---")
    for order, res in temporal_results.items():
        # NOTE(review): assumes res["depths"] maps depth -> accuracy with keys
        # inserted in ascending-depth order; the PBI's "first/middle/last"
        # positions rely on that ordering — confirm in run_all_temporal.
        pbi = _compute_pbi(list(res["depths"].values()))
        if pbi is not None:
            logger.info("  %-15s PBI=%+.3f", order, pbi)


# Script entry point: main() returns None, so the process exits with status 0
# on success; any exception from main() propagates as before.
if __name__ == "__main__":
    sys.exit(main())