abhshkp commited on
Commit
e1e1ce9
·
verified ·
1 Parent(s): 629c011

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -1,7 +1,3 @@
1
- ---
2
- tags:
3
- - ml-intern
4
- ---
5
  # Lost in the Middle — Benchmark Suite v4
6
 
7
  A modular, reproducible benchmark suite for evaluating **position bias** in long-context language models, extending the original Liu et al. (2023) experiments.
@@ -101,23 +97,3 @@ To add a new experiment:
101
  url={https://huggingface.co/abhshkp/litm-benchmark-suite-v4}
102
  }
103
  ```
104
-
105
- <!-- ml-intern-provenance -->
106
- ## Generated by ML Intern
107
-
108
- This model repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.
109
-
110
- - Try ML Intern: https://smolagents-ml-intern.hf.space
111
- - Source code: https://github.com/huggingface/ml-intern
112
-
113
- ## Usage
114
-
115
- ```python
116
- from transformers import AutoModelForCausalLM, AutoTokenizer
117
-
118
- model_id = "abhshkp/litm-benchmark-suite-v4"
119
- tokenizer = AutoTokenizer.from_pretrained(model_id)
120
- model = AutoModelForCausalLM.from_pretrained(model_id)
121
- ```
122
-
123
- For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.
 
 
 
 
 
1
  # Lost in the Middle — Benchmark Suite v4
2
 
3
  A modular, reproducible benchmark suite for evaluating **position bias** in long-context language models, extending the original Liu et al. (2023) experiments.
 
97
  url={https://huggingface.co/abhshkp/litm-benchmark-suite-v4}
98
  }
99
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Experiment modules for LITM Benchmark Suite v4."""
experiments/conversation_memory.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Experiment 7: Conversation Memory
3
+ Critical instruction buried in long chat history.
4
+ """
5
+ import logging
6
+ import os
7
+ import random
8
+ import time
9
+ from typing import List, Dict, Any
10
+
11
+ from tqdm import tqdm
12
+
13
+ from src.generator import generate_text
14
+ from src.metrics import exact_match_score, compute_accuracy, position_bias_index
15
+ from src.plotting import plot_curve
16
+ from src.utils import ensure_dir, save_jsonl, save_json
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
# Generic small-talk user turns used as filler around the buried instruction.
USER_MSGS = [
    "Hello, how are you?",
    "What is the weather like today?",
    "Tell me about quantum physics.",
    "Can you recommend a good book?",
    "What are the health benefits of green tea?",
    "Explain how airplanes fly.",
    "What is the history of the internet?",
    "How do I bake sourdough bread?",
    "What are the best hiking trails in Europe?",
    "Explain neural networks simply.",
    "What is blockchain technology?",
    "How does photosynthesis work?",
    "Tell me a joke.",
    "What is the theory of relativity?",
    "How do vaccines work?",
    "What causes earthquakes?",
    "Explain the water cycle.",
    "What is artificial intelligence?",
    "How do I learn a new language?",
    "What are black holes?",
]

# Matching filler replies; sampled independently of USER_MSGS, so turn pairs
# are not guaranteed to be topically aligned (intentional noise).
ASSISTANT_MSGS = [
    "I'm doing well, thank you!",
    "The weather varies by location and season.",
    "Quantum physics studies matter at the smallest scales.",
    "I recommend 'Sapiens' by Yuval Noah Harari.",
    "Green tea contains antioxidants that may boost metabolism.",
    "Airplanes fly due to lift generated by their wings.",
    "The internet evolved from ARPANET in the 1960s.",
    "Sourdough requires flour, water, salt, and a starter culture.",
    "The Tour du Mont Blanc is a spectacular alpine trail.",
    "Neural networks learn patterns from data through layers.",
    "Blockchain is a decentralized digital ledger.",
    "Plants convert sunlight into chemical energy.",
    "Why don't scientists trust atoms? Because they make up everything!",
    "Relativity describes how space and time are interconnected.",
    "Vaccines train the immune system to recognize pathogens.",
    "Earthquakes occur when tectonic plates shift.",
    "Water evaporates, condenses, and precipitates in a cycle.",
    "AI enables machines to perform tasks requiring human intelligence.",
    "Practice daily, immerse yourself, and use spaced repetition.",
    "Black holes have gravitational fields so strong nothing escapes.",
]
65
+
66
+
67
def _make_conversation(num_turns: int, instruction: str, ratio: float) -> str:
    """Build a synthetic chat transcript with *instruction* buried inside it.

    Args:
        num_turns: Number of filler user/assistant turn pairs.
        instruction: The critical user message to bury.
        ratio: Relative depth in [0, 1] at which the instruction is inserted
            (0 = start of the transcript, 1 = end).

    Returns:
        The transcript as "User:"/"Assistant:" lines joined by blank lines.
    """
    convo = []
    for _ in range(num_turns):  # loop index was unused; filler turns only
        convo.append(f"User: {random.choice(USER_MSGS)}")
        convo.append(f"Assistant: {random.choice(ASSISTANT_MSGS)}")

    # Splice in the instruction plus an acknowledgement at the requested depth.
    idx = int(ratio * len(convo))
    convo.insert(idx, f"User: {instruction}")
    convo.insert(idx + 1, "Assistant: I will remember that.")
    return "\n\n".join(convo)
77
+
78
+
79
def run_conversation_memory(
    model_name: str,
    num_turns: int,
    num_examples: int,
    out_dir: str,
    depths: "List[float] | None" = None,
) -> Dict[str, Any]:
    """Run the conversation-memory experiment.

    A unique "favorite color" instruction is buried at varying relative
    depths inside a synthetic chat history; the model is then asked to
    recall it.

    Args:
        model_name: HF model identifier passed to ``generate_text``.
        num_turns: Number of filler user/assistant turn pairs per transcript.
        num_examples: Examples evaluated per depth.
        out_dir: Directory receiving per-depth JSONL dumps, the summary
            JSON, and the accuracy curve plot.
        depths: Relative insertion depths in [0, 1]; defaults to 9 evenly
            spaced points.

    Returns:
        Summary dict with per-depth accuracy, position-bias index ("pbi"),
        and wall-clock time in minutes.
    """
    ensure_dir(out_dir)

    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[CONVERSATION] Depth {depth:.1%}")
        preds = []
        for i in tqdm(range(num_examples), desc=f"Conversation {depth:.1%}", leave=False):
            # Unique per-example secret so an exact match cannot occur by chance.
            secret = f"MYFAVCOLOR-{i:03d}"
            instruction = (
                f"Please always remember that my favorite color is {secret}. "
                f"This is very important."
            )
            convo = _make_conversation(num_turns, instruction, depth)
            prompt = (
                f"Here is our conversation history:\n\n{convo}\n\n"
                f"Based on our conversation, what is my favorite color? "
                f"Answer with only the color code."
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=20,
            )
            correct = exact_match_score(ans, secret)
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "secret": secret,
                "depth": depth,
            })

        # Persist raw predictions per depth before aggregating.
        save_jsonl(os.path.join(out_dir, f"conversation_depth_{depth}.jsonl"), preds)
        acc = compute_accuracy(preds)
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[CONVERSATION] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": "conversation_memory",
        "num_turns": num_turns,
        "num_examples": num_examples,
        # JSON keys must be strings, hence str(d).
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "pbi": position_bias_index(depths, [results[d]["accuracy"] for d in depths]),
        "time_minutes": (time.time() - start) / 60,
    }

    save_json(os.path.join(out_dir, "conversation_summary.json"), summary)
    plot_curve(
        depths,
        [results[d]["accuracy"] for d in depths],
        f"Exp 7: Conversation Memory ({num_turns} turns)",
        os.path.join(out_dir, "conversation_curve.png"),
        xlabel="Depth in Chat History (0=start, 1=end)",
    )

    logger.info(f"[CONVERSATION] Time={(time.time()-start)/60:.1f} min")
    return summary
experiments/fact_reasoning.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Experiment 4: Fact-Dependent Reasoning
3
+ Math problem requiring a fact hidden at varying depths.
4
+ """
5
+ import logging
6
+ import os
7
+ import random
8
+ import re
9
+ import time
10
+ from typing import List, Dict, Any
11
+
12
+ from tqdm import tqdm
13
+
14
+ from src.generator import generate_text
15
+ from src.metrics import numeric_match, compute_accuracy, position_bias_index
16
+ from src.plotting import plot_curve
17
+ from src.utils import ensure_dir, save_jsonl, save_json
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
# Topically unrelated filler sentences used to pad documents around the
# single pricing fact the model must find.
DISTRACTORS = [
    "The museum opens at 9 AM.",
    "Temperature is recorded hourly.",
    "The container weighs 2,400 kg.",
    "Ordinances ban construction near rivers.",
    "Q3 revenue increased twelve percent.",
    "The database has four million records.",
    "Solar panels generate 45 kWh daily.",
    "The manuscript was translated in the 1800s.",
    "Airport traffic peaks in summer.",
    "The compound melts at 342 Celsius.",
    "Robotic arms have 0.1mm precision.",
    "Fourteen subspecies were identified.",
    "The hall seats 2,800 guests.",
    "Wastewater uses filtration and aeration.",
    "Satellites show drought vegetation.",
]
38
+
39
+
40
def _make_doc(n: int, fact: str, ratio: float) -> str:
    """Build a document of *n* tagged distractor sentences with *fact* inserted.

    The fact lands at relative position ``ratio`` (0 = start, 1 = end).
    """
    sentences = []
    for tag in range(1, n + 1):
        sentences.append(f"{random.choice(DISTRACTORS)} [Doc {tag}]")
    insert_at = int(ratio * n)
    return " ".join(sentences[:insert_at] + [fact] + sentences[insert_at:])
45
+
46
+
47
def run_fact_reasoning(
    model_name: str,
    num_sentences: int,
    num_examples: int,
    out_dir: str,
    depths: "List[float] | None" = None,
) -> Dict[str, Any]:
    """Run the fact-dependent reasoning experiment.

    A pricing fact (unit price + discount) is buried at varying relative
    depths in a distractor document; the model must combine the fact with
    the question's quantity to compute a total cost.

    Args:
        model_name: HF model identifier passed to ``generate_text``.
        num_sentences: Number of distractor sentences per document.
        num_examples: Examples evaluated per depth.
        out_dir: Directory receiving per-depth JSONL dumps, the summary
            JSON, and the accuracy curve plot.
        depths: Relative insertion depths in [0, 1]; defaults to 9 evenly
            spaced points.

    Returns:
        Summary dict with per-depth accuracy, position-bias index ("pbi"),
        and wall-clock time in minutes.
    """
    ensure_dir(out_dir)

    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[REASON] Depth {depth:.1%}")
        preds = []
        for i in tqdm(range(num_examples), desc=f"Reason {depth:.1%}", leave=False):
            price = random.randint(2, 15)
            qty = random.randint(3, 20)
            discount = random.randint(5, 30)
            # Ground truth: price * qty with the percentage discount applied.
            answer = round(price * qty * (1 - discount / 100), 2)
            fact = f"For this order, apples cost ${price}/kg with a {discount}% discount."
            doc = _make_doc(num_sentences, fact, depth)
            prompt = (
                f"Use ONLY the document below.\n\n{doc}\n\n"
                f"Question: I buy {qty} kg of apples. What is my total cost? "
                f"Answer with only the dollar amount."
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=30,
            )
            correct = numeric_match(ans, answer, tolerance=0.5)
            # Extract the first number once (the original ran the same
            # findall twice per prediction).
            nums = re.findall(r"[\d,]+\.?\d*", ans.replace(",", ""))
            preds.append({
                "model_answer": ans,
                "predicted": float(nums[0]) if nums else -1.0,
                "correct_answer": answer,
                "correct": correct,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"reason_depth_{depth}.jsonl"), preds)
        acc = compute_accuracy(preds)
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[REASON] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": "fact_reasoning",
        "num_sentences": num_sentences,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "pbi": position_bias_index(depths, [results[d]["accuracy"] for d in depths]),
        "time_minutes": (time.time() - start) / 60,
    }

    save_json(os.path.join(out_dir, "reason_summary.json"), summary)
    plot_curve(
        depths,
        [results[d]["accuracy"] for d in depths],
        f"Exp 4: Fact-Dependent Reasoning ({num_sentences} sentences)",
        os.path.join(out_dir, "reason_curve.png"),
        xlabel="Depth in Document (0=start, 1=end)",
        ylabel="Problem-Solving Accuracy",
    )

    logger.info(f"[REASON] Time={(time.time()-start)/60:.1f} min")
    return summary
experiments/kv_retrieval.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Experiment 1: Key-Value Retrieval
3
+ Replicates Liu et al. (2023) with expanded position granularity.
4
+ Generates UUID key-value pairs, places gold pair at controlled depths.
5
+ """
6
+ import json
7
+ import logging
8
+ import os
9
+ import random
10
+ import time
11
+ import uuid
12
+ from typing import List, Dict, Any
13
+
14
+ from tqdm import tqdm
15
+
16
+ from src.generator import generate_text
17
+ from src.metrics import exact_match_score, compute_accuracy, position_bias_index
18
+ from src.plotting import plot_curve
19
+ from src.utils import ensure_dir, save_jsonl, save_json
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
def _gen_kv_data(num_keys: int, num_examples: int) -> List[Dict[str, Any]]:
    """Generate examples of UUID key/value records, each with one gold pair."""
    examples: List[Dict[str, Any]] = []
    for _ in tqdm(range(num_examples), desc=f"Gen KV data ({num_keys} keys)"):
        pairs: Dict[str, str] = {}
        # uuid4 collisions are effectively impossible; the guard just keeps
        # the record count exact regardless.
        while len(pairs) != num_keys:
            pairs[str(uuid.uuid4())] = str(uuid.uuid4())
        records = list(pairs.items())
        gold_key, gold_value = random.choice(records)
        examples.append({
            "ordered_kv_records": records,
            "key": gold_key,
            "value": gold_value,
        })
    return examples
35
+
36
+
37
+ def _format_prompt(data: List[tuple], key: str) -> str:
38
+ """Format KV data into prompt template."""
39
+ template = """Extract the value corresponding to the specified key in the JSON object below.
40
+
41
+ JSON data:
42
+ {formatted}
43
+
44
+ Key: "{key}"
45
+ Corresponding value:"""
46
+ formatted = ""
47
+ for i, (k, v) in enumerate(data):
48
+ sc = "{" if i == 0 else " "
49
+ ec = ",\n" if i != len(data) - 1 else "}"
50
+ formatted += sc + f'"{k}": "{v}"' + ec
51
+ return template.format(formatted=formatted, key=key)
52
+
53
+
54
+ def _reorder(example: Dict[str, Any], gold_pos: int) -> Dict[str, Any]:
55
+ """Move gold pair to specified position."""
56
+ ordered = example["ordered_kv_records"]
57
+ key = example["key"]
58
+ value = example["value"]
59
+ gi = next(i for i, (k, v) in enumerate(ordered) if k == key)
60
+ new = ordered[:gi] + ordered[gi + 1:]
61
+ new = new[:gold_pos] + [(key, value)] + new[gold_pos:]
62
+ return {"ordered_kv_records": new, "key": key, "value": value}
63
+
64
+
65
def run_kv_retrieval(
    model_name: str,
    num_keys: int,
    num_examples: int,
    out_dir: str,
    positions: "List[int] | None" = None,
    prefix: str = "kv",
) -> Dict[str, Any]:
    """
    Run KV retrieval experiment (Liu et al. 2023 replication).

    Args:
        model_name: HF model identifier
        num_keys: Number of KV pairs
        num_examples: Examples per position
        out_dir: Output directory
        positions: Custom position list (default: 9 evenly spaced positions)
        prefix: Filename prefix for all artifacts

    Returns:
        Summary dict with accuracy per position and PBI
    """
    ensure_dir(out_dir)

    if positions is None:
        # Eighth-step positions across the list; set() dedupes collisions
        # when num_keys is small, sorted() restores order.
        positions = sorted(set([
            0,
            num_keys // 8,
            num_keys // 4,
            3 * num_keys // 8,
            num_keys // 2,
            5 * num_keys // 8,
            3 * num_keys // 4,
            7 * num_keys // 8,
            num_keys - 1,
        ]))

    # Generate data once, then reorder for each position
    data_path = os.path.join(out_dir, f"{prefix}_data.jsonl")
    examples = _gen_kv_data(num_keys, num_examples)
    save_jsonl(data_path, examples)

    results = {}
    start = time.time()

    for pos in positions:
        logger.info(f"[{prefix}] Position {pos}/{num_keys - 1}")
        preds = []
        for ex in tqdm(examples, desc=f"{prefix} pos={pos}", leave=False):
            # Same underlying example at every position — only the gold
            # pair's location changes, isolating position as the variable.
            ro = _reorder(ex, pos)
            prompt = _format_prompt(ro["ordered_kv_records"], ro["key"])
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=80,
            )
            correct = exact_match_score(ans, ro["value"])
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "value": ro["value"],
                "gold_position": pos,
            })

        save_jsonl(os.path.join(out_dir, f"{prefix}_pos_{pos}.jsonl"), preds)
        acc = compute_accuracy(preds)
        results[pos] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[{prefix}] Pos {pos}: acc={acc:.3f}")

    # Summary: normalize positions to [0, 1] for the bias index and plot.
    norm_pos = [p / (num_keys - 1) for p in sorted(results.keys())]
    accs = [results[p]["accuracy"] for p in sorted(results.keys())]
    pbi = position_bias_index(norm_pos, accs)

    summary = {
        "experiment": "kv_retrieval",
        "num_keys": num_keys,
        "num_examples": num_examples,
        "positions": {str(p): results[p]["accuracy"] for p in sorted(results.keys())},
        "pbi": pbi,
        "time_minutes": (time.time() - start) / 60,
    }

    save_json(os.path.join(out_dir, f"{prefix}_summary.json"), summary)
    plot_curve(
        norm_pos, accs,
        f"Exp 1: KV Retrieval ({num_keys} keys)",
        os.path.join(out_dir, f"{prefix}_curve.png"),
    )

    logger.info(f"[{prefix}] PBI={pbi:.3f} | Time={(time.time()-start)/60:.1f} min")
    return summary
experiments/multi_needle.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Experiment 3: Multi-Needle Retrieval
3
+ Tests ability to retrieve ALL of multiple needles placed at start, middle, and end.
4
+ """
5
+ import logging
6
+ import os
7
+ import random
8
+ import time
9
+ from typing import Dict, Any
10
+
11
+ from tqdm import tqdm
12
+
13
+ from src.generator import generate_text
14
+ from src.metrics import exact_match_score, compute_accuracy
15
+ from src.plotting import plot_bar
16
+ from src.utils import ensure_dir, save_json
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ from .needle_in_haystack import FILLERS
21
+
22
+
23
def _make_haystack(n: int) -> str:
    """Generate *n* filler sentences, each tagged with its 1-based index."""
    pieces = []
    for idx in range(1, n + 1):
        pieces.append(f"{random.choice(FILLERS)} [{idx}].")
    return " ".join(pieces)
25
+
26
+
27
def run_multi_needle(
    model_name: str,
    num_sentences: int,
    num_examples: int,
    out_dir: str,
) -> Dict[str, Any]:
    """Run the multi-needle experiment.

    Three secret codes are placed at the start, middle, and end of a filler
    document; the model must list all three. Per-region recall is reported
    separately to expose position bias within a single prompt.

    Args:
        model_name: HF model identifier passed to ``generate_text``.
        num_sentences: Number of filler sentences in the haystack.
        num_examples: Number of prompts evaluated.
        out_dir: Directory receiving the summary JSON and bar plot.

    Returns:
        Summary dict with start/middle/end accuracy and time in minutes.
    """
    ensure_dir(out_dir)

    start = time.time()
    start_ok, mid_ok, end_ok = [], [], []

    for i in tqdm(range(num_examples), desc="Multi-needle"):
        filler = _make_haystack(num_sentences)
        # Re-split the haystack into sentences so needles can be spliced in.
        sents = [s.strip() + "." for s in filler.split(".") if s.strip()]
        n = len(sents)
        # Distinct per-example codes so matches cannot occur by chance.
        ca, cb, cc = f"ALPHA-{i:03d}", f"BETA-{i:03d}", f"GAMMA-{i:03d}"

        sents.insert(0, f"The first secret code is {ca}.")
        sents.insert(n // 2, f"The second secret code is {cb}.")
        sents.append(f"The third secret code is {cc}.")

        prompt = (
            f"Read the text and list ALL three secret codes in order.\n\n"
            f"{' '.join(sents)}\n\nCodes:"
        )
        ans = generate_text(
            [{"role": "user", "content": prompt}],
            model_name=model_name,
            max_new_tokens=60,
        )
        # Each code is scored independently against the single response.
        start_ok.append(exact_match_score(ans, ca))
        mid_ok.append(exact_match_score(ans, cb))
        end_ok.append(exact_match_score(ans, cc))

    summary = {
        "experiment": "multi_needle",
        "num_sentences": num_sentences,
        "num_examples": num_examples,
        "start": compute_accuracy([{"correct": c} for c in start_ok]),
        "middle": compute_accuracy([{"correct": c} for c in mid_ok]),
        "end": compute_accuracy([{"correct": c} for c in end_ok]),
        "time_minutes": (time.time() - start) / 60,
    }

    logger.info(
        f"[MULTI] Start={summary['start']:.3f} Mid={summary['middle']:.3f} End={summary['end']:.3f}"
    )

    save_json(os.path.join(out_dir, "multi_summary.json"), summary)
    plot_bar(
        ["Start", "Middle", "End"],
        [summary["start"], summary["middle"], summary["end"]],
        f"Exp 3: Multi-Needle (n={num_examples})",
        os.path.join(out_dir, "multi_bar.png"),
    )

    return summary
experiments/needle_in_haystack.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Experiment 2: Needle in Haystack (text)
3
+ Tests retrieval of a secret code hidden at varying depths in filler text.
4
+ """
5
+ import logging
6
+ import os
7
+ import random
8
+ import time
9
+ from typing import List, Dict, Any
10
+
11
+ from tqdm import tqdm
12
+
13
+ from src.generator import generate_text
14
+ from src.metrics import exact_match_score, compute_accuracy, position_bias_index
15
+ from src.plotting import plot_curve
16
+ from src.utils import ensure_dir, save_jsonl, save_json
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
# Topically unrelated filler sentences used to build the haystack around the
# secret-code needle. Also imported by experiments.multi_needle.
FILLERS = [
    "The history of pottery spans thousands of years.",
    "Marine biologists study coral reef ecosystems.",
    "Railway engineering requires precise curvature.",
    "The periodic table arranges elements by number.",
    "Clouds are classified as cumulus and stratus.",
    "Beekeeping traditions differ between continents.",
    "The Great Wall was built over many dynasties.",
    "Thermodynamics governs heat transfer.",
    "Impressionist painters captured fleeting light.",
    "Volcanic activity is tracked with seismographs.",
    "The Dewey Decimal System organizes libraries.",
    "Irrigation evolved from canals to drip systems.",
    "Neural networks are inspired by biological brains.",
    "Light speed is 299,792,458 meters per second.",
    "Classical composition follows harmonic rules.",
    "Urban planning addresses zoning and transport.",
    "Photosynthesis converts CO2 into glucose.",
    "The Fibonacci sequence appears in nature.",
    "GPS uses triangulation from satellites.",
    "Cryptography secures digital communication.",
]
42
+
43
+
44
def _make_haystack(n: int) -> str:
    """Generate n sentences of filler text, each tagged with its 1-based index."""
    tagged = [f"{random.choice(FILLERS)} [{pos}]." for pos in range(1, n + 1)]
    return " ".join(tagged)
47
+
48
+
49
+ def _insert_needle(text: str, needle: str, ratio: float) -> str:
50
+ """Insert needle at specified depth ratio."""
51
+ sents = [s.strip() + "." for s in text.split(".") if s.strip()]
52
+ idx = int(ratio * len(sents))
53
+ sents.insert(idx, needle)
54
+ return " ".join(sents)
55
+
56
+
57
def run_needle_in_haystack(
    model_name: str,
    num_sentences: int,
    num_examples: int,
    out_dir: str,
    depths: "List[float] | None" = None,
) -> Dict[str, Any]:
    """Run the needle-in-haystack experiment.

    A secret code sentence is hidden at varying relative depths inside
    filler text; the model is asked to retrieve the code.

    Args:
        model_name: HF model identifier passed to ``generate_text``.
        num_sentences: Number of filler sentences per haystack.
        num_examples: Examples evaluated per depth.
        out_dir: Directory receiving per-depth JSONL dumps, the summary
            JSON, and the accuracy curve plot.
        depths: Relative insertion depths in [0, 1]; defaults to 9 evenly
            spaced points.

    Returns:
        Summary dict with per-depth accuracy, position-bias index ("pbi"),
        and wall-clock time in minutes.
    """
    ensure_dir(out_dir)

    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[NEEDLE] Depth {depth:.1%}")
        preds = []
        for i in tqdm(range(num_examples), desc=f"Needle {depth:.1%}", leave=False):
            filler = _make_haystack(num_sentences)
            # Unique per-example code so an exact match cannot occur by chance.
            code = f"SECRET-{i:04d}"
            needle = f"The secret code is {code}."
            text = _insert_needle(filler, needle, depth)
            prompt = (
                f"Read the text and find the secret code.\n\n{text}\n\n"
                f"What is the secret code? Answer with only the code."
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=20,
            )
            correct = exact_match_score(ans, code)
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "secret": code,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"needle_depth_{depth}.jsonl"), preds)
        acc = compute_accuracy(preds)
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[NEEDLE] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": "needle_in_haystack",
        "num_sentences": num_sentences,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "pbi": position_bias_index(depths, [results[d]["accuracy"] for d in depths]),
        "time_minutes": (time.time() - start) / 60,
    }

    save_json(os.path.join(out_dir, "needle_summary.json"), summary)
    plot_curve(
        depths,
        [results[d]["accuracy"] for d in depths],
        f"Exp 2: Needle in Haystack ({num_sentences} sentences)",
        os.path.join(out_dir, "needle_curve.png"),
        xlabel="Depth in Document (0=start, 1=end)",
    )

    logger.info(f"[NEEDLE] Time={(time.time()-start)/60:.1f} min")
    return summary
experiments/semantic_distractors.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Experiment 5: Semantic Similarity Distractors
3
+ Gold fact ("capital of France is Paris") among semantically similar facts.
4
+ """
5
+ import logging
6
+ import os
7
+ import random
8
+ import time
9
+ from typing import List, Dict, Any
10
+
11
+ from tqdm import tqdm
12
+
13
+ from src.generator import generate_text
14
+ from src.metrics import exact_match_score, compute_accuracy, position_bias_index
15
+ from src.plotting import plot_curve
16
+ from src.utils import ensure_dir, save_jsonl, save_json
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
# Fact templates filled with randomly paired slot values below. Pairings are
# deliberately arbitrary, so distractor "facts" may be false — they only need
# to be semantically similar to the gold fact, not true.
TEMPLATES = [
    "The capital of {country} is {city}.",
    "The population of {country} is approximately {num} million.",
    "The official language of {country} is {lang}.",
    "The currency of {country} is the {currency}.",
    "The largest city in {country} is {city}.",
]

# Note: France and Paris are intentionally absent from these pools, so no
# distractor can collide with the gold fact "The capital of France is Paris."
COUNTRIES = [
    "Germany", "Spain", "Italy", "Brazil", "Argentina", "Canada",
    "Australia", "Japan", "China", "India", "Russia", "Egypt",
    "Turkey", "Mexico", "South Korea", "Thailand", "Vietnam",
    "Poland", "Sweden", "Norway", "Denmark", "Finland", "Greece",
    "Portugal", "Ireland", "Austria", "Switzerland", "Belgium",
    "Netherlands", "Czech Republic", "Hungary", "Romania",
]

CITIES = [
    "Berlin", "Madrid", "Rome", "Brasilia", "Buenos Aires", "Ottawa",
    "Canberra", "Tokyo", "Beijing", "New Delhi", "Moscow", "Cairo",
    "Ankara", "Mexico City", "Seoul", "Bangkok", "Hanoi",
    "Warsaw", "Stockholm", "Oslo", "Copenhagen", "Helsinki", "Athens",
    "Lisbon", "Dublin", "Vienna", "Bern", "Brussels",
    "Amsterdam", "Prague", "Budapest", "Bucharest",
]

LANGS = [
    "German", "Spanish", "Italian", "Portuguese", "French",
    "English", "Japanese", "Mandarin", "Hindi", "Russian",
    "Arabic", "Turkish", "Korean", "Thai", "Vietnamese",
    "Polish", "Swedish", "Norwegian", "Danish", "Finnish",
    "Greek", "Irish", "Dutch", "Czech", "Hungarian", "Romanian",
]

CURRENCIES = [
    "Euro", "Peso", "Real", "Dollar", "Yen", "Yuan", "Rupee",
    "Ruble", "Pound", "Won", "Baht", "Dong", "Zloty",
    "Krone", "Krona", "Forint", "Leu", "Franc",
]
59
+
60
+
61
def _make_doc(num_facts: int, gold_fact: str, ratio: float) -> str:
    """Build a numbered fact list with *gold_fact* inserted at relative depth *ratio*."""
    distractors = []
    for _ in range(num_facts):
        template = random.choice(TEMPLATES)
        # All slot values are drawn even when a template uses only some of
        # them; .format ignores unused keyword arguments.
        distractors.append(
            template.format(
                country=random.choice(COUNTRIES),
                city=random.choice(CITIES),
                num=random.randint(10, 1400),
                lang=random.choice(LANGS),
                currency=random.choice(CURRENCIES),
            )
        )

    insert_at = int(ratio * len(distractors))
    distractors.insert(insert_at, gold_fact)
    return "\n".join(f"{pos + 1}. {fact}" for pos, fact in enumerate(distractors))
77
+
78
+
79
def run_semantic_distractors(
    model_name: str,
    num_facts: int,
    num_examples: int,
    out_dir: str,
    depths: "List[float] | None" = None,
) -> Dict[str, Any]:
    """Run the semantic-distractor experiment.

    The fixed gold fact "The capital of France is Paris." is buried among
    semantically similar country facts at varying relative depths.

    Args:
        model_name: HF model identifier passed to ``generate_text``.
        num_facts: Number of distractor facts per document.
        num_examples: Examples evaluated per depth.
        out_dir: Directory receiving per-depth JSONL dumps, the summary
            JSON, and the accuracy curve plot.
        depths: Relative insertion depths in [0, 1]; defaults to 9 evenly
            spaced points.

    Returns:
        Summary dict with per-depth accuracy, position-bias index ("pbi"),
        and wall-clock time in minutes.
    """
    ensure_dir(out_dir)

    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[SEMANTIC] Depth {depth:.1%}")
        preds = []
        for i in tqdm(range(num_examples), desc=f"Semantic {depth:.1%}", leave=False):
            gold = "The capital of France is Paris."
            doc = _make_doc(num_facts, gold, depth)
            prompt = (
                f"Read the following list of facts and answer the question.\n\n{doc}\n\n"
                f"Question: What is the capital of France? Answer with only the city name."
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=20,
            )
            # Matched lowercase — presumably exact_match_score normalizes
            # case; verify against src.metrics.
            correct = exact_match_score(ans, "paris")
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"semantic_depth_{depth}.jsonl"), preds)
        acc = compute_accuracy(preds)
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[SEMANTIC] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": "semantic_distractors",
        "num_facts": num_facts,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "pbi": position_bias_index(depths, [results[d]["accuracy"] for d in depths]),
        "time_minutes": (time.time() - start) / 60,
    }

    save_json(os.path.join(out_dir, "semantic_summary.json"), summary)
    plot_curve(
        depths,
        [results[d]["accuracy"] for d in depths],
        f"Exp 5: Semantic Similarity Distractors ({num_facts} facts)",
        os.path.join(out_dir, "semantic_curve.png"),
        xlabel="Depth in Document (0=start, 1=end)",
    )

    logger.info(f"[SEMANTIC] Time={(time.time()-start)/60:.1f} min")
    return summary
experiments/temporal_narrative.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Experiment 6: Temporal Narrative
3
+ Recall an event from a long chronological timeline.
4
+ """
5
+ import logging
6
+ import os
7
+ import random
8
+ import re
9
+ import time
10
+ from typing import List, Dict, Any
11
+
12
+ from tqdm import tqdm
13
+
14
+ from src.generator import generate_text
15
+ from src.metrics import exact_match_score, compute_accuracy, position_bias_index
16
+ from src.plotting import plot_curve
17
+ from src.utils import ensure_dir, save_jsonl, save_json
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
# Distinct narrative events sampled (without replacement) to build timelines;
# padded with generic ceremonies when more events are requested than exist here.
EVENTS_POOL = [
    "the king issued a decree",
    "a comet appeared in the sky",
    "the bridge was completed",
    "a treaty was signed",
    "the harvest festival began",
    "a stranger arrived at the gates",
    "the library burned down",
    "a new star was discovered",
    "the river flooded the town",
    "the army marched north",
    "a peace envoy was sent",
    "the market was opened",
    "a plague swept the city",
    "the old temple was restored",
    "a fleet set sail for distant lands",
    "the academy admitted its first students",
    "a rebellion broke out in the east",
    "the queen gave birth to twins",
    "a dragon was spotted in the mountains",
    "the great bell tolled for the first time",
]
43
+
44
+
45
def _make_timeline(num_events: int, target_event: str, ratio: float) -> str:
    """Build a year-stamped timeline with *target_event* inserted at depth *ratio*.

    Years start at 1000 and increase by one per event, so the target's year
    is 1000 plus its insertion index.
    """
    # Sample unique events first; pad with generic ceremonies if the pool
    # is smaller than num_events.
    events = random.sample(EVENTS_POOL, min(num_events, len(EVENTS_POOL)))
    while len(events) < num_events:
        time_of_day = random.choice(['morning', 'evening', 'midday'])
        events.append(f"the people gathered for the {time_of_day} ceremony")

    insert_at = int(ratio * len(events))
    events.insert(insert_at, target_event)
    lines = [f"Year {1000+i}: {e}." for i, e in enumerate(events)]
    return "\n".join(lines)
54
+
55
+
56
def run_temporal_narrative(
    model_name: str,
    num_events: int,
    num_examples: int,
    out_dir: str,
    depths: "List[float] | None" = None,
) -> Dict[str, Any]:
    """Run the temporal-narrative experiment.

    A target event is placed at varying relative depths in a chronological
    timeline; the model must recall the year it occurred.

    Args:
        model_name: HF model identifier passed to ``generate_text``.
        num_events: Number of filler events per timeline.
        num_examples: Examples evaluated per depth.
        out_dir: Directory receiving per-depth JSONL dumps, the summary
            JSON, and the accuracy curve plot.
        depths: Relative insertion depths in [0, 1]; defaults to 9 evenly
            spaced points.

    Returns:
        Summary dict with per-depth accuracy, position-bias index ("pbi"),
        and wall-clock time in minutes.
    """
    ensure_dir(out_dir)

    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[NARRATIVE] Depth {depth:.1%}")
        preds = []
        for i in tqdm(range(num_examples), desc=f"Narrative {depth:.1%}", leave=False):
            target = "a golden statue was unveiled in the central square"
            timeline = _make_timeline(num_events, target, depth)
            prompt = (
                f"Read the following timeline of historical events.\n\n{timeline}\n\n"
                f"Question: In which year was a golden statue unveiled in the central square? "
                f"Answer with only the year number."
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=15,
            )
            # _make_timeline inserts the target at int(depth * num_events),
            # and years run 1000 + index, so this reconstruction is exact.
            expected_year = 1000 + int(depth * num_events)
            # Accept any 4-digit year within +/-4 of the expected one.
            years = re.findall(r"\b\d{4}\b", ans)
            correct = 1.0 if any(abs(int(y) - expected_year) < 5 for y in years) else 0.0
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "expected_year": expected_year,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"narrative_depth_{depth}.jsonl"), preds)
        acc = compute_accuracy(preds)
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[NARRATIVE] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": "temporal_narrative",
        "num_events": num_events,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "pbi": position_bias_index(depths, [results[d]["accuracy"] for d in depths]),
        "time_minutes": (time.time() - start) / 60,
    }

    save_json(os.path.join(out_dir, "narrative_summary.json"), summary)
    plot_curve(
        depths,
        [results[d]["accuracy"] for d in depths],
        f"Exp 6: Temporal Narrative ({num_events} events)",
        os.path.join(out_dir, "narrative_curve.png"),
        xlabel="Depth in Timeline (0=start, 1=end)",
    )

    logger.info(f"[NARRATIVE] Time={(time.time()-start)/60:.1f} min")
    return summary
run_all.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ ================================================================================
4
+ LOST IN THE MIDDLE — Benchmark Suite v4 (Master Runner)
5
+ ================================================================================
6
+ Runs all 7 experiments with configurable model, counts, and output directory.
7
+ Usage:
8
+ python run_all.py --model Qwen/Qwen2.5-1.5B-Instruct --output ./results
9
+ ================================================================================
10
+ """
11
+ import argparse
12
+ import json
13
+ import logging
14
+ import os
15
+ import shutil
16
+ import sys
17
+ import time
18
+
19
+ from experiments.kv_retrieval import run_kv_retrieval
20
+ from experiments.needle_in_haystack import run_needle_in_haystack
21
+ from experiments.multi_needle import run_multi_needle
22
+ from experiments.fact_reasoning import run_fact_reasoning
23
+ from experiments.semantic_distractors import run_semantic_distractors
24
+ from experiments.temporal_narrative import run_temporal_narrative
25
+ from experiments.conversation_memory import run_conversation_memory
26
+ from src.utils import save_json
27
+
28
+ logging.basicConfig(
29
+ format="%(asctime)s - %(levelname)s - %(message)s",
30
+ level=logging.INFO,
31
+ stream=sys.stdout,
32
+ )
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
def parse_args():
    """Build the CLI parser for the benchmark suite and parse sys.argv."""
    parser = argparse.ArgumentParser(description="LITM Benchmark Suite v4")
    parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct", help="HF model name")
    parser.add_argument("--output", default="./results", help="Output directory")
    parser.add_argument("--n-examples", type=int, default=50, help="Examples per position")
    # Per-experiment size knobs: all plain integer flags differing only in default.
    for flag, default in [
        ("--n-keys-100", 100),
        ("--n-keys-200", 200),
        ("--needle-sentences", 500),
        ("--multi-sentences", 300),
        ("--reason-sentences", 300),
        ("--semantic-facts", 80),
        ("--narrative-events", 100),
        ("--convo-turns", 100),
    ]:
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--experiments", default="all", help="Comma-separated list or 'all'")
    parser.add_argument("--zip", action="store_true", help="Create zip archive of results")
    return parser.parse_args()
52
+
53
+
54
def main():
    """Run the selected LITM experiments end to end.

    Parses CLI options, runs each requested experiment into its own
    sub-directory of --output, saves a combined master_summary.json,
    prints a PBI summary table, and optionally zips the results.
    """
    args = parse_args()
    model = args.model
    out_root = args.output
    os.makedirs(out_root, exist_ok=True)

    # Fix: strip whitespace so `--experiments "kv100, needle"` also matches;
    # previously " needle" would silently never run.
    if args.experiments != "all":
        wanted = {e.strip() for e in args.experiments.split(",")}
    else:
        wanted = {"all"}

    logger.info("=" * 70)
    logger.info("LITM BENCHMARK SUITE v4")
    logger.info(f"Model: {model} | Output: {out_root}")
    logger.info("=" * 70)

    all_results = {}
    t0 = time.time()

    def should_run(name):
        # "all" enables every experiment; otherwise match by exact name.
        return "all" in wanted or name in wanted

    if should_run("kv100"):
        logger.info("\n--- EXP 1A: KV Retrieval (100 keys) ---")
        all_results["kv_100"] = run_kv_retrieval(
            model_name=model,
            num_keys=args.n_keys_100,
            num_examples=args.n_examples,
            out_dir=os.path.join(out_root, "exp1a_kv100"),
            prefix="kv100",
        )

    if should_run("kv200"):
        logger.info("\n--- EXP 1B: KV Retrieval (200 keys) ---")
        all_results["kv_200"] = run_kv_retrieval(
            model_name=model,
            num_keys=args.n_keys_200,
            num_examples=args.n_examples,
            out_dir=os.path.join(out_root, "exp1b_kv200"),
            prefix="kv200",
        )

    if should_run("needle"):
        logger.info("\n--- EXP 2: Needle in Haystack ---")
        all_results["needle"] = run_needle_in_haystack(
            model_name=model,
            num_sentences=args.needle_sentences,
            num_examples=30,
            out_dir=os.path.join(out_root, "exp2_needle"),
        )

    if should_run("multi"):
        logger.info("\n--- EXP 3: Multi-Needle ---")
        all_results["multi"] = run_multi_needle(
            model_name=model,
            num_sentences=args.multi_sentences,
            num_examples=30,
            out_dir=os.path.join(out_root, "exp3_multi"),
        )

    if should_run("reason"):
        logger.info("\n--- EXP 4: Fact-Dependent Reasoning ---")
        all_results["reason"] = run_fact_reasoning(
            model_name=model,
            num_sentences=args.reason_sentences,
            num_examples=30,
            out_dir=os.path.join(out_root, "exp4_reason"),
        )

    if should_run("semantic"):
        logger.info("\n--- EXP 5: Semantic Similarity Distractors ---")
        all_results["semantic"] = run_semantic_distractors(
            model_name=model,
            num_facts=args.semantic_facts,
            num_examples=30,
            out_dir=os.path.join(out_root, "exp5_semantic"),
        )

    if should_run("narrative"):
        logger.info("\n--- EXP 6: Temporal Narrative ---")
        all_results["narrative"] = run_temporal_narrative(
            model_name=model,
            num_events=args.narrative_events,
            num_examples=30,
            out_dir=os.path.join(out_root, "exp6_narrative"),
        )

    if should_run("conversation"):
        logger.info("\n--- EXP 7: Conversation Memory ---")
        all_results["conversation"] = run_conversation_memory(
            model_name=model,
            num_turns=args.convo_turns,
            num_examples=30,
            out_dir=os.path.join(out_root, "exp7_conversation"),
        )

    elapsed = (time.time() - t0) / 3600  # wall-clock hours
    logger.info(f"\n{'='*70}")
    logger.info(f"COMPLETE. Total time: {elapsed:.2f} hours")
    logger.info(f"Results: {out_root}")
    logger.info(f"{'='*70}")

    save_json(os.path.join(out_root, "master_summary.json"), all_results)

    # PBI table: only experiment summaries that report a "pbi" field appear.
    logger.info("\n--- Position Bias Index (PBI) Summary ---")
    for k, v in all_results.items():
        if isinstance(v, dict) and "pbi" in v:
            logger.info(f"  {k:20s} PBI = {v['pbi']:+.3f}")

    if args.zip:
        # The archive is written next to the results directory, not inside it,
        # so the zip does not recursively include itself.
        zip_path = os.path.join(os.path.dirname(out_root), "litm_results_all")
        shutil.make_archive(zip_path, "zip", out_root)
        logger.info(f"Zipped: {zip_path}.zip")


if __name__ == "__main__":
    main()
src/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """LITM Benchmark Suite v4 — Core Library"""
2
+ __version__ = "4.0.0"
3
+ __author__ = "abhshkp"
src/generator.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Text generation wrapper with chat-template support."""
2
+ import logging
3
+ import torch
4
+ from typing import List, Dict, Any
5
+ from .model_loader import load_model
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
def generate_text(
    messages: List[Dict[str, str]],
    model_name: str,
    max_new_tokens: int = 80,
    load_in_4bit: bool = True,
) -> str:
    """Run greedy chat generation and return only the newly generated text.

    *messages* is a chat-formatted list of ``{"role", "content"}`` dicts;
    the model/tokenizer pair is obtained via ``load_model`` (which caches).
    """
    model, tokenizer = load_model(model_name, load_in_4bit=load_in_4bit)

    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        return_tensors="pt",
        add_generation_prompt=True,
        return_dict=True,
    )

    # Move every input tensor onto the device the model weights live on.
    device = next(model.parameters()).device
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    prompt_len = encoded["input_ids"].shape[1]
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
            pad_token_id=tokenizer.pad_token_id,
        )

    # Slice off the prompt tokens and decode only the continuation.
    continuation = generated[0][prompt_len:]
    return tokenizer.decode(continuation, skip_special_tokens=True).strip()
src/metrics.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Metrics and scoring utilities."""
2
+ import re
3
+ import statistics
4
+ from typing import List, Dict, Any
5
+
6
+
7
+ def exact_match_score(prediction: str, target: str) -> float:
8
+ """Binary exact-match score."""
9
+ return 1.0 if target.lower() in prediction.lower() else 0.0
10
+
11
+
12
def numeric_match(prediction: str, target: float, tolerance: float = 0.5) -> float:
    """Extract the first number from *prediction* and compare it to *target*.

    Commas are stripped first so "1,234" parses as 1234.0. Returns 1.0 when
    the first number found is within *tolerance* of *target*, else 0.0
    (including when no number is present). Signs are not captured, so "-5"
    parses as 5 — acceptable here since benchmark answers are non-negative.
    """
    cleaned = prediction.replace(",", "")
    # Original pattern was r"[\d,]+\.?\d*"; the "," in the class was dead
    # because commas are removed above, so the plain-digit form is equivalent.
    nums = re.findall(r"\d+\.?\d*", cleaned)
    if not nums:
        return 0.0
    return 1.0 if abs(float(nums[0]) - target) < tolerance else 0.0
19
+
20
+
21
def compute_accuracy(predictions: List[Dict[str, Any]], key: str = "correct") -> float:
    """Mean of the *key* field over prediction records; 0.0 for an empty list."""
    if not predictions:
        return 0.0
    return statistics.mean(record[key] for record in predictions)
25
+
26
+
27
def position_bias_index(positions: List[float], accuracies: List[float]) -> float:
    """
    Position Bias Index (PBI): the mean of the first and last accuracies
    minus the accuracy at the middle index. Larger values indicate a
    stronger U-shaped (lost-in-the-middle) curve. Returns 0.0 when fewer
    than three positions are given, where the index is undefined.
    """
    if len(positions) < 3:
        return 0.0
    middle_acc = accuracies[len(positions) // 2]
    edge_acc = 0.5 * (accuracies[0] + accuracies[-1])
    return edge_acc - middle_acc
src/model_loader.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model loading with 4-bit quantization for T4/GPU inference."""
2
+ import logging
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+ _model_cache = {}
9
+ _tok_cache = {}
10
+
11
+
12
def load_model(model_name: str, load_in_4bit: bool = True, device_map: str = "auto"):
    """Load *model_name* (optionally NF4-quantized) plus its tokenizer.

    Results are memoized per (model_name, load_in_4bit, device_map), so
    repeated calls across experiments reuse the already-loaded weights.
    """
    key = f"{model_name}:{load_in_4bit}:{device_map}"
    if key in _model_cache:
        return _model_cache[key], _tok_cache[key]

    logger.info(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Models shipping without a pad token fall back to EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    common_kwargs = dict(
        device_map=device_map,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    if load_in_4bit:
        common_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    model = AutoModelForCausalLM.from_pretrained(model_name, **common_kwargs)

    model.eval()
    logger.info(f"Model loaded on {next(model.parameters()).device}")

    _model_cache[key] = model
    _tok_cache[key] = tokenizer
    return model, tokenizer
src/plotting.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Plotting utilities for position-bias curves."""
2
+ import logging
3
+ import matplotlib
4
+ matplotlib.use("Agg")
5
+ import matplotlib.pyplot as plt
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
def plot_curve(
    x_values,
    y_values,
    title: str,
    save_path: str,
    xlabel: str = "Position (0=start, 1=end)",
    ylabel: str = "Accuracy",
    ylim: tuple = (-0.05, 1.05),
    color: str = "#E63946",
):
    """Render a single position-bias accuracy curve and save it as a PNG."""
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(x_values, y_values, marker="o", linewidth=2.5, markersize=10, color=color)
    ax.set_xlabel(xlabel, fontsize=13)
    ax.set_ylabel(ylabel, fontsize=13)
    ax.set_title(title, fontsize=13)
    ax.set_ylim(ylim)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path, dpi=200)
    plt.close(fig)
    logger.info(f"Plot saved: {save_path}")
32
+
33
+
34
def plot_bar(categories, values, title: str, save_path: str, ylabel: str = "Accuracy", ylim=(0, 1.05), colors=None):
    """Save a bar chart (e.g. start/middle/end accuracies) to *save_path*."""
    bar_colors = ["#2E86AB", "#E63946", "#2E86AB"] if colors is None else colors
    fig, ax = plt.subplots(figsize=(6, 5))
    ax.bar(categories, values, color=bar_colors, edgecolor="black", linewidth=1.2)
    ax.set_ylabel(ylabel, fontsize=13)
    ax.set_title(title, fontsize=13)
    ax.set_ylim(ylim)
    ax.grid(True, alpha=0.3, axis="y")
    fig.tight_layout()
    fig.savefig(save_path, dpi=200)
    plt.close(fig)
    logger.info(f"Bar plot saved: {save_path}")
48
+
49
+
50
def plot_multi_curves(curves, labels, title, save_path, xlabel="Position", ylabel="Accuracy"):
    """Overlay several accuracy curves on one axes for side-by-side comparison.

    *curves* is a dict holding parallel lists under "x" and "y"; *labels*
    pairs with them positionally.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    palette = plt.get_cmap("tab10")
    for idx, (xs, ys, name) in enumerate(zip(curves["x"], curves["y"], labels)):
        ax.plot(xs, ys, marker="o", linewidth=2.0, markersize=8, label=name, color=palette(idx))
    ax.set_xlabel(xlabel, fontsize=13)
    ax.set_ylabel(ylabel, fontsize=13)
    ax.set_title(title, fontsize=13)
    ax.set_ylim(-0.05, 1.05)
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path, dpi=200)
    plt.close(fig)
    logger.info(f"Multi-curve plot saved: {save_path}")
src/utils.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Common utilities."""
2
+ import json
3
+ import os
4
+ import logging
5
+ from typing import List, Dict, Any
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
def ensure_dir(path: str):
    """Create directory *path* (parents included) if it doesn't exist; no-op when present."""
    os.makedirs(path, exist_ok=True)
13
+
14
+
15
def save_jsonl(path: str, records: List[Dict[str, Any]]):
    """Save *records* to *path* as JSON Lines (one JSON object per line).

    Writes with an explicit UTF-8 encoding so output does not depend on the
    platform's locale default.
    """
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")
+
21
+
22
+ def load_jsonl(path: str) -> List[Dict[str, Any]]:
23
+ """Load JSONL records."""
24
+ records = []
25
+ with open(path) as f:
26
+ for line in f:
27
+ records.append(json.loads(line))
28
+ return records
29
+
30
+
31
def save_json(path: str, data: Any):
    """Save *data* as pretty-printed (indent=2) JSON, UTF-8 encoded.

    The explicit encoding avoids locale-dependent defaults on Windows.
    """
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
+
36
+
37
def load_json(path: str) -> Any:
    """Load and return the JSON document stored at *path* (read as UTF-8)."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)