Upload folder using huggingface_hub
Browse files- README.md +0 -24
- experiments/__init__.py +1 -0
- experiments/conversation_memory.py +147 -0
- experiments/fact_reasoning.py +117 -0
- experiments/kv_retrieval.py +156 -0
- experiments/multi_needle.py +84 -0
- experiments/needle_in_haystack.py +122 -0
- experiments/semantic_distractors.py +141 -0
- experiments/temporal_narrative.py +122 -0
- run_all.py +168 -0
- src/__init__.py +3 -0
- src/generator.py +39 -0
- src/metrics.py +38 -0
- src/model_loader.py +51 -0
- src/plotting.py +65 -0
- src/utils.py +40 -0
README.md
CHANGED
|
@@ -1,7 +1,3 @@
|
|
| 1 |
-
---
|
| 2 |
-
tags:
|
| 3 |
-
- ml-intern
|
| 4 |
-
---
|
| 5 |
# Lost in the Middle — Benchmark Suite v4
|
| 6 |
|
| 7 |
A modular, reproducible benchmark suite for evaluating **position bias** in long-context language models, extending the original Liu et al. (2023) experiments.
|
|
@@ -101,23 +97,3 @@ To add a new experiment:
|
|
| 101 |
url={https://huggingface.co/abhshkp/litm-benchmark-suite-v4}
|
| 102 |
}
|
| 103 |
```
|
| 104 |
-
|
| 105 |
-
<!-- ml-intern-provenance -->
|
| 106 |
-
## Generated by ML Intern
|
| 107 |
-
|
| 108 |
-
This model repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.
|
| 109 |
-
|
| 110 |
-
- Try ML Intern: https://smolagents-ml-intern.hf.space
|
| 111 |
-
- Source code: https://github.com/huggingface/ml-intern
|
| 112 |
-
|
| 113 |
-
## Usage
|
| 114 |
-
|
| 115 |
-
```python
|
| 116 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 117 |
-
|
| 118 |
-
model_id = "abhshkp/litm-benchmark-suite-v4"
|
| 119 |
-
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 120 |
-
model = AutoModelForCausalLM.from_pretrained(model_id)
|
| 121 |
-
```
|
| 122 |
-
|
| 123 |
-
For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# Lost in the Middle — Benchmark Suite v4
|
| 2 |
|
| 3 |
A modular, reproducible benchmark suite for evaluating **position bias** in long-context language models, extending the original Liu et al. (2023) experiments.
|
|
|
|
| 97 |
url={https://huggingface.co/abhshkp/litm-benchmark-suite-v4}
|
| 98 |
}
|
| 99 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
experiments/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Experiment modules for LITM Benchmark Suite v4."""
|
experiments/conversation_memory.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Experiment 7: Conversation Memory
|
| 3 |
+
Critical instruction buried in long chat history.
|
| 4 |
+
"""
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import time
|
| 9 |
+
from typing import List, Dict, Any
|
| 10 |
+
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from src.generator import generate_text
|
| 14 |
+
from src.metrics import exact_match_score, compute_accuracy, position_bias_index
|
| 15 |
+
from src.plotting import plot_curve
|
| 16 |
+
from src.utils import ensure_dir, save_jsonl, save_json
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
USER_MSGS = [
|
| 21 |
+
"Hello, how are you?",
|
| 22 |
+
"What is the weather like today?",
|
| 23 |
+
"Tell me about quantum physics.",
|
| 24 |
+
"Can you recommend a good book?",
|
| 25 |
+
"What are the health benefits of green tea?",
|
| 26 |
+
"Explain how airplanes fly.",
|
| 27 |
+
"What is the history of the internet?",
|
| 28 |
+
"How do I bake sourdough bread?",
|
| 29 |
+
"What are the best hiking trails in Europe?",
|
| 30 |
+
"Explain neural networks simply.",
|
| 31 |
+
"What is blockchain technology?",
|
| 32 |
+
"How does photosynthesis work?",
|
| 33 |
+
"Tell me a joke.",
|
| 34 |
+
"What is the theory of relativity?",
|
| 35 |
+
"How do vaccines work?",
|
| 36 |
+
"What causes earthquakes?",
|
| 37 |
+
"Explain the water cycle.",
|
| 38 |
+
"What is artificial intelligence?",
|
| 39 |
+
"How do I learn a new language?",
|
| 40 |
+
"What are black holes?",
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
ASSISTANT_MSGS = [
|
| 44 |
+
"I'm doing well, thank you!",
|
| 45 |
+
"The weather varies by location and season.",
|
| 46 |
+
"Quantum physics studies matter at the smallest scales.",
|
| 47 |
+
"I recommend 'Sapiens' by Yuval Noah Harari.",
|
| 48 |
+
"Green tea contains antioxidants that may boost metabolism.",
|
| 49 |
+
"Airplanes fly due to lift generated by their wings.",
|
| 50 |
+
"The internet evolved from ARPANET in the 1960s.",
|
| 51 |
+
"Sourdough requires flour, water, salt, and a starter culture.",
|
| 52 |
+
"The Tour du Mont Blanc is a spectacular alpine trail.",
|
| 53 |
+
"Neural networks learn patterns from data through layers.",
|
| 54 |
+
"Blockchain is a decentralized digital ledger.",
|
| 55 |
+
"Plants convert sunlight into chemical energy.",
|
| 56 |
+
"Why don't scientists trust atoms? Because they make up everything!",
|
| 57 |
+
"Relativity describes how space and time are interconnected.",
|
| 58 |
+
"Vaccines train the immune system to recognize pathogens.",
|
| 59 |
+
"Earthquakes occur when tectonic plates shift.",
|
| 60 |
+
"Water evaporates, condenses, and precipitates in a cycle.",
|
| 61 |
+
"AI enables machines to perform tasks requiring human intelligence.",
|
| 62 |
+
"Practice daily, immerse yourself, and use spaced repetition.",
|
| 63 |
+
"Black holes have gravitational fields so strong nothing escapes.",
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _make_conversation(num_turns: int, instruction: str, ratio: float) -> str:
    """Build a synthetic chat transcript with *instruction* injected at a depth.

    Args:
        num_turns: Number of filler user/assistant exchange pairs.
        instruction: The critical user message to bury in the history.
        ratio: Relative injection depth (0 = start, 1 = end of the transcript).

    Returns:
        The full transcript, messages separated by blank lines.
    """
    turns = []
    for _ in range(num_turns):
        turns.append(f"User: {random.choice(USER_MSGS)}")
        turns.append(f"Assistant: {random.choice(ASSISTANT_MSGS)}")

    # Splice the instruction (plus an acknowledgement) at the requested depth.
    insert_at = int(ratio * len(turns))
    turns[insert_at:insert_at] = [
        f"User: {instruction}",
        "Assistant: I will remember that.",
    ]
    return "\n\n".join(turns)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def run_conversation_memory(
    model_name: str,
    num_turns: int,
    num_examples: int,
    out_dir: str,
    depths: List[float] = None,
) -> Dict[str, Any]:
    """Run conversation memory experiment.

    A "remember my favorite color" instruction is buried at varying depths in
    a synthetic chat history; the model is then asked to recall the color.

    Args:
        model_name: HF model identifier passed through to the generator.
        num_turns: Number of filler user/assistant exchange pairs per transcript.
        num_examples: Examples evaluated per depth.
        out_dir: Directory for per-depth JSONL predictions, summary JSON, and plot.
        depths: Relative instruction positions (0=start, 1=end); defaults to
            9 evenly spaced depths.

    Returns:
        Summary dict with per-depth accuracy, position-bias index, and timing.
    """
    ensure_dir(out_dir)

    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[CONVERSATION] Depth {depth:.1%}")
        preds = []
        for i in tqdm(range(num_examples), desc=f"Conversation {depth:.1%}", leave=False):
            # Unique per-example secret so a correct answer cannot be guessed.
            secret = f"MYFAVCOLOR-{i:03d}"
            instruction = (
                f"Please always remember that my favorite color is {secret}. "
                f"This is very important."
            )
            convo = _make_conversation(num_turns, instruction, depth)
            prompt = (
                f"Here is our conversation history:\n\n{convo}\n\n"
                f"Based on our conversation, what is my favorite color? "
                f"Answer with only the color code."
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=20,
            )
            correct = exact_match_score(ans, secret)
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "secret": secret,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"conversation_depth_{depth}.jsonl"), preds)
        acc = compute_accuracy(preds)
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[CONVERSATION] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": "conversation_memory",
        "num_turns": num_turns,
        "num_examples": num_examples,
        # Keys stringified so the summary is JSON-serializable.
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "pbi": position_bias_index(depths, [results[d]["accuracy"] for d in depths]),
        "time_minutes": (time.time() - start) / 60,
    }

    save_json(os.path.join(out_dir, "conversation_summary.json"), summary)
    plot_curve(
        depths,
        [results[d]["accuracy"] for d in depths],
        f"Exp 7: Conversation Memory ({num_turns} turns)",
        os.path.join(out_dir, "conversation_curve.png"),
        xlabel="Depth in Chat History (0=start, 1=end)",
    )

    logger.info(f"[CONVERSATION] Time={(time.time()-start)/60:.1f} min")
    return summary
|
experiments/fact_reasoning.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Experiment 4: Fact-Dependent Reasoning
|
| 3 |
+
Math problem requiring a fact hidden at varying depths.
|
| 4 |
+
"""
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import re
|
| 9 |
+
import time
|
| 10 |
+
from typing import List, Dict, Any
|
| 11 |
+
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from src.generator import generate_text
|
| 15 |
+
from src.metrics import numeric_match, compute_accuracy, position_bias_index
|
| 16 |
+
from src.plotting import plot_curve
|
| 17 |
+
from src.utils import ensure_dir, save_jsonl, save_json
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
DISTRACTORS = [
|
| 22 |
+
"The museum opens at 9 AM.",
|
| 23 |
+
"Temperature is recorded hourly.",
|
| 24 |
+
"The container weighs 2,400 kg.",
|
| 25 |
+
"Ordinances ban construction near rivers.",
|
| 26 |
+
"Q3 revenue increased twelve percent.",
|
| 27 |
+
"The database has four million records.",
|
| 28 |
+
"Solar panels generate 45 kWh daily.",
|
| 29 |
+
"The manuscript was translated in the 1800s.",
|
| 30 |
+
"Airport traffic peaks in summer.",
|
| 31 |
+
"The compound melts at 342 Celsius.",
|
| 32 |
+
"Robotic arms have 0.1mm precision.",
|
| 33 |
+
"Fourteen subspecies were identified.",
|
| 34 |
+
"The hall seats 2,800 guests.",
|
| 35 |
+
"Wastewater uses filtration and aeration.",
|
| 36 |
+
"Satellites show drought vegetation.",
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _make_doc(n: int, fact: str, ratio: float) -> str:
    """Assemble an n-sentence distractor document with *fact* inserted at depth *ratio*."""
    filler = [f"{random.choice(DISTRACTORS)} [Doc {i+1}]" for i in range(n)]
    filler.insert(int(ratio * len(filler)), fact)
    return " ".join(filler)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _extract_first_number(text: str) -> float:
    """Return the first numeric token in *text* as a float, or -1.0 if none."""
    matches = re.findall(r"[\d,]+\.?\d*", text.replace(",", ""))
    return float(matches[0]) if matches else -1.0


def run_fact_reasoning(
    model_name: str,
    num_sentences: int,
    num_examples: int,
    out_dir: str,
    depths: List[float] = None,
) -> Dict[str, Any]:
    """Run fact-dependent reasoning experiment.

    A pricing fact (price/kg plus a percentage discount) is buried at varying
    depths in a distractor document; the model must locate it and compute the
    total cost for a given quantity.

    Args:
        model_name: HF model identifier passed through to the generator.
        num_sentences: Number of distractor sentences per document.
        num_examples: Examples evaluated per depth.
        out_dir: Directory for per-depth JSONL predictions, summary JSON, and plot.
        depths: Relative fact positions (0=start, 1=end); defaults to 9
            evenly spaced depths.

    Returns:
        Summary dict with per-depth accuracy, position-bias index, and timing.
    """
    ensure_dir(out_dir)

    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[REASON] Depth {depth:.1%}")
        preds = []
        for i in tqdm(range(num_examples), desc=f"Reason {depth:.1%}", leave=False):
            price = random.randint(2, 15)
            qty = random.randint(3, 20)
            discount = random.randint(5, 30)
            # Ground truth: price * qty after the percentage discount.
            answer = round(price * qty * (1 - discount / 100), 2)
            fact = f"For this order, apples cost ${price}/kg with a {discount}% discount."
            doc = _make_doc(num_sentences, fact, depth)
            prompt = (
                f"Use ONLY the document below.\n\n{doc}\n\n"
                f"Question: I buy {qty} kg of apples. What is my total cost? "
                f"Answer with only the dollar amount."
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=30,
            )
            correct = numeric_match(ans, answer, tolerance=0.5)
            preds.append({
                "model_answer": ans,
                # Was an inline ternary that ran the same re.findall twice;
                # now computed once via the helper above.
                "predicted": _extract_first_number(ans),
                "correct_answer": answer,
                "correct": correct,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"reason_depth_{depth}.jsonl"), preds)
        acc = compute_accuracy(preds)
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[REASON] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": "fact_reasoning",
        "num_sentences": num_sentences,
        "num_examples": num_examples,
        # Keys stringified so the summary is JSON-serializable.
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "pbi": position_bias_index(depths, [results[d]["accuracy"] for d in depths]),
        "time_minutes": (time.time() - start) / 60,
    }

    save_json(os.path.join(out_dir, "reason_summary.json"), summary)
    plot_curve(
        depths,
        [results[d]["accuracy"] for d in depths],
        f"Exp 4: Fact-Dependent Reasoning ({num_sentences} sentences)",
        os.path.join(out_dir, "reason_curve.png"),
        xlabel="Depth in Document (0=start, 1=end)",
        ylabel="Problem-Solving Accuracy",
    )

    logger.info(f"[REASON] Time={(time.time()-start)/60:.1f} min")
    return summary
|
experiments/kv_retrieval.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Experiment 1: Key-Value Retrieval
|
| 3 |
+
Replicates Liu et al. (2023) with expanded position granularity.
|
| 4 |
+
Generates UUID key-value pairs, places gold pair at controlled depths.
|
| 5 |
+
"""
|
| 6 |
+
import json
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import random
|
| 10 |
+
import time
|
| 11 |
+
import uuid
|
| 12 |
+
from typing import List, Dict, Any
|
| 13 |
+
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
from src.generator import generate_text
|
| 17 |
+
from src.metrics import exact_match_score, compute_accuracy, position_bias_index
|
| 18 |
+
from src.plotting import plot_curve
|
| 19 |
+
from src.utils import ensure_dir, save_jsonl, save_json
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _gen_kv_data(num_keys: int, num_examples: int) -> List[Dict[str, Any]]:
    """Generate key-value pair examples.

    Each example holds `num_keys` UUID->UUID pairs plus one randomly chosen
    gold key/value to retrieve.
    """
    out = []
    for _ in tqdm(range(num_examples), desc=f"Gen KV data ({num_keys} keys)"):
        pairs = {}
        # Dict keys dedupe UUID collisions (vanishingly rare, but exact count matters).
        while len(pairs) != num_keys:
            pairs[str(uuid.uuid4())] = str(uuid.uuid4())
        records = list(pairs.items())
        gold_key, gold_value = random.choice(records)
        out.append({"ordered_kv_records": records, "key": gold_key, "value": gold_value})
    return out
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _format_prompt(data: List[tuple], key: str) -> str:
|
| 38 |
+
"""Format KV data into prompt template."""
|
| 39 |
+
template = """Extract the value corresponding to the specified key in the JSON object below.
|
| 40 |
+
|
| 41 |
+
JSON data:
|
| 42 |
+
{formatted}
|
| 43 |
+
|
| 44 |
+
Key: "{key}"
|
| 45 |
+
Corresponding value:"""
|
| 46 |
+
formatted = ""
|
| 47 |
+
for i, (k, v) in enumerate(data):
|
| 48 |
+
sc = "{" if i == 0 else " "
|
| 49 |
+
ec = ",\n" if i != len(data) - 1 else "}"
|
| 50 |
+
formatted += sc + f'"{k}": "{v}"' + ec
|
| 51 |
+
return template.format(formatted=formatted, key=key)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _reorder(example: Dict[str, Any], gold_pos: int) -> Dict[str, Any]:
|
| 55 |
+
"""Move gold pair to specified position."""
|
| 56 |
+
ordered = example["ordered_kv_records"]
|
| 57 |
+
key = example["key"]
|
| 58 |
+
value = example["value"]
|
| 59 |
+
gi = next(i for i, (k, v) in enumerate(ordered) if k == key)
|
| 60 |
+
new = ordered[:gi] + ordered[gi + 1:]
|
| 61 |
+
new = new[:gold_pos] + [(key, value)] + new[gold_pos:]
|
| 62 |
+
return {"ordered_kv_records": new, "key": key, "value": value}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def run_kv_retrieval(
    model_name: str,
    num_keys: int,
    num_examples: int,
    out_dir: str,
    positions: List[int] = None,
    prefix: str = "kv",
) -> Dict[str, Any]:
    """
    Run KV retrieval experiment.

    Args:
        model_name: HF model identifier
        num_keys: Number of KV pairs
        num_examples: Examples per position
        out_dir: Output directory
        positions: Custom position list (default: 9 positions)
        prefix: Filename prefix

    Returns:
        Summary dict with accuracy per position and PBI
    """
    ensure_dir(out_dir)

    if positions is None:
        # Nine positions at eighths of the key list; set() dedupes collisions
        # from integer division when num_keys is small.
        positions = sorted(set([
            0,
            num_keys // 8,
            num_keys // 4,
            3 * num_keys // 8,
            num_keys // 2,
            5 * num_keys // 8,
            3 * num_keys // 4,
            7 * num_keys // 8,
            num_keys - 1,
        ]))

    # Generate data once, then reorder for each position
    data_path = os.path.join(out_dir, f"{prefix}_data.jsonl")
    examples = _gen_kv_data(num_keys, num_examples)
    save_jsonl(data_path, examples)

    results = {}
    start = time.time()

    for pos in positions:
        logger.info(f"[{prefix}] Position {pos}/{num_keys - 1}")
        preds = []
        for ex in tqdm(examples, desc=f"{prefix} pos={pos}", leave=False):
            # Same example set per position; only the gold pair's slot moves.
            ro = _reorder(ex, pos)
            prompt = _format_prompt(ro["ordered_kv_records"], ro["key"])
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=80,
            )
            correct = exact_match_score(ans, ro["value"])
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "value": ro["value"],
                "gold_position": pos,
            })

        save_jsonl(os.path.join(out_dir, f"{prefix}_pos_{pos}.jsonl"), preds)
        acc = compute_accuracy(preds)
        results[pos] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[{prefix}] Pos {pos}: acc={acc:.3f}")

    # Summary
    # Normalize absolute positions to [0, 1] for the PBI metric and plot.
    norm_pos = [p / (num_keys - 1) for p in sorted(results.keys())]
    accs = [results[p]["accuracy"] for p in sorted(results.keys())]
    pbi = position_bias_index(norm_pos, accs)

    summary = {
        "experiment": "kv_retrieval",
        "num_keys": num_keys,
        "num_examples": num_examples,
        "positions": {str(p): results[p]["accuracy"] for p in sorted(results.keys())},
        "pbi": pbi,
        "time_minutes": (time.time() - start) / 60,
    }

    save_json(os.path.join(out_dir, f"{prefix}_summary.json"), summary)
    plot_curve(
        norm_pos, accs,
        f"Exp 1: KV Retrieval ({num_keys} keys)",
        os.path.join(out_dir, f"{prefix}_curve.png"),
    )

    logger.info(f"[{prefix}] PBI={pbi:.3f} | Time={(time.time()-start)/60:.1f} min")
    return summary
|
experiments/multi_needle.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Experiment 3: Multi-Needle Retrieval
|
| 3 |
+
Tests ability to retrieve ALL of multiple needles placed at start, middle, and end.
|
| 4 |
+
"""
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import time
|
| 9 |
+
from typing import Dict, Any
|
| 10 |
+
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from src.generator import generate_text
|
| 14 |
+
from src.metrics import exact_match_score, compute_accuracy
|
| 15 |
+
from src.plotting import plot_bar
|
| 16 |
+
from src.utils import ensure_dir, save_json
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
from .needle_in_haystack import FILLERS
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _make_haystack(n: int) -> str:
    """Return n random filler sentences, each tagged with its 1-based index."""
    tagged = [f"{random.choice(FILLERS)} [{idx + 1}]." for idx in range(n)]
    return " ".join(tagged)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def run_multi_needle(
    model_name: str,
    num_sentences: int,
    num_examples: int,
    out_dir: str,
) -> Dict[str, Any]:
    """Run multi-needle experiment.

    Three secret codes are placed at the start, middle, and end of a filler
    document; per-position recall is scored independently.

    Args:
        model_name: HF model identifier passed through to the generator.
        num_sentences: Filler sentences per haystack.
        num_examples: Number of documents evaluated.
        out_dir: Directory for the summary JSON and bar plot.

    Returns:
        Summary dict with start/middle/end accuracy and timing.
    """
    ensure_dir(out_dir)

    start = time.time()
    start_ok, mid_ok, end_ok = [], [], []

    for i in tqdm(range(num_examples), desc="Multi-needle"):
        filler = _make_haystack(num_sentences)
        # Re-split on "." to get a sentence list to insert into.
        sents = [s.strip() + "." for s in filler.split(".") if s.strip()]
        n = len(sents)
        ca, cb, cc = f"ALPHA-{i:03d}", f"BETA-{i:03d}", f"GAMMA-{i:03d}"

        sents.insert(0, f"The first secret code is {ca}.")
        # NOTE(review): n is the pre-insertion length, and the first insert has
        # already shifted indices by one — so the second needle lands near, not
        # exactly at, the midpoint. Appears intentional ("middle") — confirm.
        sents.insert(n // 2, f"The second secret code is {cb}.")
        sents.append(f"The third secret code is {cc}.")

        prompt = (
            f"Read the text and list ALL three secret codes in order.\n\n"
            f"{' '.join(sents)}\n\nCodes:"
        )
        ans = generate_text(
            [{"role": "user", "content": prompt}],
            model_name=model_name,
            max_new_tokens=60,
        )
        # Each needle is scored independently against the single response.
        start_ok.append(exact_match_score(ans, ca))
        mid_ok.append(exact_match_score(ans, cb))
        end_ok.append(exact_match_score(ans, cc))

    summary = {
        "experiment": "multi_needle",
        "num_sentences": num_sentences,
        "num_examples": num_examples,
        "start": compute_accuracy([{"correct": c} for c in start_ok]),
        "middle": compute_accuracy([{"correct": c} for c in mid_ok]),
        "end": compute_accuracy([{"correct": c} for c in end_ok]),
        "time_minutes": (time.time() - start) / 60,
    }

    logger.info(
        f"[MULTI] Start={summary['start']:.3f} Mid={summary['middle']:.3f} End={summary['end']:.3f}"
    )

    save_json(os.path.join(out_dir, "multi_summary.json"), summary)
    plot_bar(
        ["Start", "Middle", "End"],
        [summary["start"], summary["middle"], summary["end"]],
        f"Exp 3: Multi-Needle (n={num_examples})",
        os.path.join(out_dir, "multi_bar.png"),
    )

    return summary
|
experiments/needle_in_haystack.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Experiment 2: Needle in Haystack (text)
|
| 3 |
+
Tests retrieval of a secret code hidden at varying depths in filler text.
|
| 4 |
+
"""
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import time
|
| 9 |
+
from typing import List, Dict, Any
|
| 10 |
+
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from src.generator import generate_text
|
| 14 |
+
from src.metrics import exact_match_score, compute_accuracy, position_bias_index
|
| 15 |
+
from src.plotting import plot_curve
|
| 16 |
+
from src.utils import ensure_dir, save_jsonl, save_json
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
FILLERS = [
|
| 21 |
+
"The history of pottery spans thousands of years.",
|
| 22 |
+
"Marine biologists study coral reef ecosystems.",
|
| 23 |
+
"Railway engineering requires precise curvature.",
|
| 24 |
+
"The periodic table arranges elements by number.",
|
| 25 |
+
"Clouds are classified as cumulus and stratus.",
|
| 26 |
+
"Beekeeping traditions differ between continents.",
|
| 27 |
+
"The Great Wall was built over many dynasties.",
|
| 28 |
+
"Thermodynamics governs heat transfer.",
|
| 29 |
+
"Impressionist painters captured fleeting light.",
|
| 30 |
+
"Volcanic activity is tracked with seismographs.",
|
| 31 |
+
"The Dewey Decimal System organizes libraries.",
|
| 32 |
+
"Irrigation evolved from canals to drip systems.",
|
| 33 |
+
"Neural networks are inspired by biological brains.",
|
| 34 |
+
"Light speed is 299,792,458 meters per second.",
|
| 35 |
+
"Classical composition follows harmonic rules.",
|
| 36 |
+
"Urban planning addresses zoning and transport.",
|
| 37 |
+
"Photosynthesis converts CO2 into glucose.",
|
| 38 |
+
"The Fibonacci sequence appears in nature.",
|
| 39 |
+
"GPS uses triangulation from satellites.",
|
| 40 |
+
"Cryptography secures digital communication.",
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _make_haystack(n: int) -> str:
    """Generate n sentences of filler text, each tagged with its 1-based index."""
    tagged = []
    for idx in range(n):
        tagged.append(f"{random.choice(FILLERS)} [{idx + 1}].")
    return " ".join(tagged)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _insert_needle(text: str, needle: str, ratio: float) -> str:
|
| 50 |
+
"""Insert needle at specified depth ratio."""
|
| 51 |
+
sents = [s.strip() + "." for s in text.split(".") if s.strip()]
|
| 52 |
+
idx = int(ratio * len(sents))
|
| 53 |
+
sents.insert(idx, needle)
|
| 54 |
+
return " ".join(sents)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def run_needle_in_haystack(
    model_name: str,
    num_sentences: int,
    num_examples: int,
    out_dir: str,
    depths: List[float] = None,
) -> Dict[str, Any]:
    """Run needle-in-haystack experiment.

    A secret code sentence is inserted at varying depths in filler text; the
    model must retrieve the code verbatim.

    Args:
        model_name: HF model identifier passed through to the generator.
        num_sentences: Filler sentences per haystack.
        num_examples: Examples evaluated per depth.
        out_dir: Directory for per-depth JSONL predictions, summary JSON, and plot.
        depths: Relative needle positions (0=start, 1=end); defaults to 9
            evenly spaced depths.

    Returns:
        Summary dict with per-depth accuracy, position-bias index, and timing.
    """
    ensure_dir(out_dir)

    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[NEEDLE] Depth {depth:.1%}")
        preds = []
        for i in tqdm(range(num_examples), desc=f"Needle {depth:.1%}", leave=False):
            filler = _make_haystack(num_sentences)
            # Unique per-example code so a correct answer cannot be guessed.
            code = f"SECRET-{i:04d}"
            needle = f"The secret code is {code}."
            text = _insert_needle(filler, needle, depth)
            prompt = (
                f"Read the text and find the secret code.\n\n{text}\n\n"
                f"What is the secret code? Answer with only the code."
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=20,
            )
            correct = exact_match_score(ans, code)
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "secret": code,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"needle_depth_{depth}.jsonl"), preds)
        acc = compute_accuracy(preds)
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[NEEDLE] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": "needle_in_haystack",
        "num_sentences": num_sentences,
        "num_examples": num_examples,
        # Keys stringified so the summary is JSON-serializable.
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "pbi": position_bias_index(depths, [results[d]["accuracy"] for d in depths]),
        "time_minutes": (time.time() - start) / 60,
    }

    save_json(os.path.join(out_dir, "needle_summary.json"), summary)
    plot_curve(
        depths,
        [results[d]["accuracy"] for d in depths],
        f"Exp 2: Needle in Haystack ({num_sentences} sentences)",
        os.path.join(out_dir, "needle_curve.png"),
        xlabel="Depth in Document (0=start, 1=end)",
    )

    logger.info(f"[NEEDLE] Time={(time.time()-start)/60:.1f} min")
    return summary
|
experiments/semantic_distractors.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Experiment 5: Semantic Similarity Distractors
|
| 3 |
+
Gold fact ("capital of France is Paris") among semantically similar facts.
|
| 4 |
+
"""
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import time
|
| 9 |
+
from typing import List, Dict, Any
|
| 10 |
+
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from src.generator import generate_text
|
| 14 |
+
from src.metrics import exact_match_score, compute_accuracy, position_bias_index
|
| 15 |
+
from src.plotting import plot_curve
|
| 16 |
+
from src.utils import ensure_dir, save_jsonl, save_json
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
TEMPLATES = [
|
| 21 |
+
"The capital of {country} is {city}.",
|
| 22 |
+
"The population of {country} is approximately {num} million.",
|
| 23 |
+
"The official language of {country} is {lang}.",
|
| 24 |
+
"The currency of {country} is the {currency}.",
|
| 25 |
+
"The largest city in {country} is {city}.",
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
COUNTRIES = [
|
| 29 |
+
"Germany", "Spain", "Italy", "Brazil", "Argentina", "Canada",
|
| 30 |
+
"Australia", "Japan", "China", "India", "Russia", "Egypt",
|
| 31 |
+
"Turkey", "Mexico", "South Korea", "Thailand", "Vietnam",
|
| 32 |
+
"Poland", "Sweden", "Norway", "Denmark", "Finland", "Greece",
|
| 33 |
+
"Portugal", "Ireland", "Austria", "Switzerland", "Belgium",
|
| 34 |
+
"Netherlands", "Czech Republic", "Hungary", "Romania",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
CITIES = [
|
| 38 |
+
"Berlin", "Madrid", "Rome", "Brasilia", "Buenos Aires", "Ottawa",
|
| 39 |
+
"Canberra", "Tokyo", "Beijing", "New Delhi", "Moscow", "Cairo",
|
| 40 |
+
"Ankara", "Mexico City", "Seoul", "Bangkok", "Hanoi",
|
| 41 |
+
"Warsaw", "Stockholm", "Oslo", "Copenhagen", "Helsinki", "Athens",
|
| 42 |
+
"Lisbon", "Dublin", "Vienna", "Bern", "Brussels",
|
| 43 |
+
"Amsterdam", "Prague", "Budapest", "Bucharest",
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
LANGS = [
|
| 47 |
+
"German", "Spanish", "Italian", "Portuguese", "French",
|
| 48 |
+
"English", "Japanese", "Mandarin", "Hindi", "Russian",
|
| 49 |
+
"Arabic", "Turkish", "Korean", "Thai", "Vietnamese",
|
| 50 |
+
"Polish", "Swedish", "Norwegian", "Danish", "Finnish",
|
| 51 |
+
"Greek", "Irish", "Dutch", "Czech", "Hungarian", "Romanian",
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
CURRENCIES = [
|
| 55 |
+
"Euro", "Peso", "Real", "Dollar", "Yen", "Yuan", "Rupee",
|
| 56 |
+
"Ruble", "Pound", "Won", "Baht", "Dong", "Zloty",
|
| 57 |
+
"Krone", "Krona", "Forint", "Leu", "Franc",
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _make_doc(num_facts: int, gold_fact: str, ratio: float) -> str:
|
| 62 |
+
facts = []
|
| 63 |
+
for _ in range(num_facts):
|
| 64 |
+
t = random.choice(TEMPLATES)
|
| 65 |
+
fact = t.format(
|
| 66 |
+
country=random.choice(COUNTRIES),
|
| 67 |
+
city=random.choice(CITIES),
|
| 68 |
+
num=random.randint(10, 1400),
|
| 69 |
+
lang=random.choice(LANGS),
|
| 70 |
+
currency=random.choice(CURRENCIES),
|
| 71 |
+
)
|
| 72 |
+
facts.append(fact)
|
| 73 |
+
|
| 74 |
+
idx = int(ratio * len(facts))
|
| 75 |
+
facts.insert(idx, gold_fact)
|
| 76 |
+
return "\n".join(f"{i+1}. {f}" for i, f in enumerate(facts))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def run_semantic_distractors(
    model_name: str,
    num_facts: int,
    num_examples: int,
    out_dir: str,
    depths: List[float] = None,
) -> Dict[str, Any]:
    """Run the semantic-distractor position-bias experiment.

    A single gold fact ("The capital of France is Paris.") is embedded in a
    list of *num_facts* semantically similar distractor facts, once per depth
    in *depths* (0.0 = start of the list, 1.0 = end).  For every depth,
    *num_examples* prompts are generated and scored by substring match
    against "paris".

    Writes per-depth JSONL predictions, a JSON summary, and an accuracy
    curve plot into *out_dir*; returns the summary dict.
    """
    ensure_dir(out_dir)

    # Default: nine evenly spaced insertion depths across the document.
    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[SEMANTIC] Depth {depth:.1%}")
        preds = []
        for i in tqdm(range(num_examples), desc=f"Semantic {depth:.1%}", leave=False):
            gold = "The capital of France is Paris."
            # Fresh random distractors per example; only the gold position is fixed.
            doc = _make_doc(num_facts, gold, depth)
            prompt = (
                f"Read the following list of facts and answer the question.\n\n{doc}\n\n"
                f"Question: What is the capital of France? Answer with only the city name."
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=20,
            )
            # Containment match, so verbose answers ("It is Paris.") still count.
            correct = exact_match_score(ans, "paris")
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"semantic_depth_{depth}.jsonl"), preds)
        acc = compute_accuracy(preds)
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[SEMANTIC] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": "semantic_distractors",
        "num_facts": num_facts,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        # PBI > 0 means the middle depths underperform the two edges (U-shape).
        "pbi": position_bias_index(depths, [results[d]["accuracy"] for d in depths]),
        "time_minutes": (time.time() - start) / 60,
    }

    save_json(os.path.join(out_dir, "semantic_summary.json"), summary)
    plot_curve(
        depths,
        [results[d]["accuracy"] for d in depths],
        f"Exp 5: Semantic Similarity Distractors ({num_facts} facts)",
        os.path.join(out_dir, "semantic_curve.png"),
        xlabel="Depth in Document (0=start, 1=end)",
    )

    logger.info(f"[SEMANTIC] Time={(time.time()-start)/60:.1f} min")
    return summary
|
experiments/temporal_narrative.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Experiment 6: Temporal Narrative
|
| 3 |
+
Recall an event from a long chronological timeline.
|
| 4 |
+
"""
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import re
|
| 9 |
+
import time
|
| 10 |
+
from typing import List, Dict, Any
|
| 11 |
+
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from src.generator import generate_text
|
| 15 |
+
from src.metrics import exact_match_score, compute_accuracy, position_bias_index
|
| 16 |
+
from src.plotting import plot_curve
|
| 17 |
+
from src.utils import ensure_dir, save_jsonl, save_json
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
EVENTS_POOL = [
|
| 22 |
+
"the king issued a decree",
|
| 23 |
+
"a comet appeared in the sky",
|
| 24 |
+
"the bridge was completed",
|
| 25 |
+
"a treaty was signed",
|
| 26 |
+
"the harvest festival began",
|
| 27 |
+
"a stranger arrived at the gates",
|
| 28 |
+
"the library burned down",
|
| 29 |
+
"a new star was discovered",
|
| 30 |
+
"the river flooded the town",
|
| 31 |
+
"the army marched north",
|
| 32 |
+
"a peace envoy was sent",
|
| 33 |
+
"the market was opened",
|
| 34 |
+
"a plague swept the city",
|
| 35 |
+
"the old temple was restored",
|
| 36 |
+
"a fleet set sail for distant lands",
|
| 37 |
+
"the academy admitted its first students",
|
| 38 |
+
"a rebellion broke out in the east",
|
| 39 |
+
"the queen gave birth to twins",
|
| 40 |
+
"a dragon was spotted in the mountains",
|
| 41 |
+
"the great bell tolled for the first time",
|
| 42 |
+
]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _make_timeline(num_events: int, target_event: str, ratio: float) -> str:
|
| 46 |
+
events = random.sample(EVENTS_POOL, min(num_events, len(EVENTS_POOL)))
|
| 47 |
+
while len(events) < num_events:
|
| 48 |
+
events.append(
|
| 49 |
+
f"the people gathered for the {random.choice(['morning', 'evening', 'midday'])} ceremony"
|
| 50 |
+
)
|
| 51 |
+
idx = int(ratio * len(events))
|
| 52 |
+
events.insert(idx, target_event)
|
| 53 |
+
return "\n".join(f"Year {1000+i}: {e}." for i, e in enumerate(events))
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def run_temporal_narrative(
    model_name: str,
    num_events: int,
    num_examples: int,
    out_dir: str,
    depths: List[float] = None,
) -> Dict[str, Any]:
    """Run the temporal-narrative position-bias experiment.

    A target event ("a golden statue was unveiled in the central square") is
    inserted into a chronological timeline of *num_events* filler events at
    each depth in *depths*; the model must report the year of the target
    event.  Scoring tolerates small year errors (|predicted - expected| < 5).

    Writes per-depth JSONL predictions, a JSON summary, and an accuracy
    curve plot into *out_dir*; returns the summary dict.
    """
    ensure_dir(out_dir)

    # Default: nine evenly spaced insertion depths across the timeline.
    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[NARRATIVE] Depth {depth:.1%}")
        preds = []
        for i in tqdm(range(num_examples), desc=f"Narrative {depth:.1%}", leave=False):
            target = "a golden statue was unveiled in the central square"
            timeline = _make_timeline(num_events, target, depth)
            prompt = (
                f"Read the following timeline of historical events.\n\n{timeline}\n\n"
                f"Question: In which year was a golden statue unveiled in the central square? "
                f"Answer with only the year number."
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=15,
            )
            # Mirrors the insertion index in _make_timeline: years start at 1000
            # and the target lands at index int(depth * num_events).
            expected_year = 1000 + int(depth * num_events)
            # Accept any 4-digit year within +/-4 of the expected one.
            years = re.findall(r"\b\d{4}\b", ans)
            correct = 1.0 if any(abs(int(y) - expected_year) < 5 for y in years) else 0.0
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "expected_year": expected_year,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"narrative_depth_{depth}.jsonl"), preds)
        acc = compute_accuracy(preds)
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[NARRATIVE] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": "temporal_narrative",
        "num_events": num_events,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        # PBI > 0 means middle depths underperform the edges (U-shape).
        "pbi": position_bias_index(depths, [results[d]["accuracy"] for d in depths]),
        "time_minutes": (time.time() - start) / 60,
    }

    save_json(os.path.join(out_dir, "narrative_summary.json"), summary)
    plot_curve(
        depths,
        [results[d]["accuracy"] for d in depths],
        f"Exp 6: Temporal Narrative ({num_events} events)",
        os.path.join(out_dir, "narrative_curve.png"),
        xlabel="Depth in Timeline (0=start, 1=end)",
    )

    logger.info(f"[NARRATIVE] Time={(time.time()-start)/60:.1f} min")
    return summary
|
run_all.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
================================================================================
|
| 4 |
+
LOST IN THE MIDDLE — Benchmark Suite v4 (Master Runner)
|
| 5 |
+
================================================================================
|
| 6 |
+
Runs all 7 experiments with configurable model, counts, and output directory.
|
| 7 |
+
Usage:
|
| 8 |
+
python run_all.py --model Qwen/Qwen2.5-1.5B-Instruct --output ./results
|
| 9 |
+
================================================================================
|
| 10 |
+
"""
|
| 11 |
+
import argparse
|
| 12 |
+
import json
|
| 13 |
+
import logging
|
| 14 |
+
import os
|
| 15 |
+
import shutil
|
| 16 |
+
import sys
|
| 17 |
+
import time
|
| 18 |
+
|
| 19 |
+
from experiments.kv_retrieval import run_kv_retrieval
|
| 20 |
+
from experiments.needle_in_haystack import run_needle_in_haystack
|
| 21 |
+
from experiments.multi_needle import run_multi_needle
|
| 22 |
+
from experiments.fact_reasoning import run_fact_reasoning
|
| 23 |
+
from experiments.semantic_distractors import run_semantic_distractors
|
| 24 |
+
from experiments.temporal_narrative import run_temporal_narrative
|
| 25 |
+
from experiments.conversation_memory import run_conversation_memory
|
| 26 |
+
from src.utils import save_json
|
| 27 |
+
|
| 28 |
+
logging.basicConfig(
|
| 29 |
+
format="%(asctime)s - %(levelname)s - %(message)s",
|
| 30 |
+
level=logging.INFO,
|
| 31 |
+
stream=sys.stdout,
|
| 32 |
+
)
|
| 33 |
+
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def parse_args():
    """Parse CLI options for the benchmark runner.

    All sizing knobs default to the paper-scale configuration; `--experiments`
    takes a comma-separated subset or "all".
    """
    parser = argparse.ArgumentParser(description="LITM Benchmark Suite v4")
    parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct", help="HF model name")
    parser.add_argument("--output", default="./results", help="Output directory")
    parser.add_argument("--n-examples", type=int, default=50, help="Examples per position")
    parser.add_argument("--n-keys-100", type=int, default=100)
    parser.add_argument("--n-keys-200", type=int, default=200)
    parser.add_argument("--needle-sentences", type=int, default=500)
    parser.add_argument("--multi-sentences", type=int, default=300)
    parser.add_argument("--reason-sentences", type=int, default=300)
    parser.add_argument("--semantic-facts", type=int, default=80)
    parser.add_argument("--narrative-events", type=int, default=100)
    parser.add_argument("--convo-turns", type=int, default=100)
    parser.add_argument("--experiments", default="all", help="Comma-separated list or 'all'")
    parser.add_argument("--zip", action="store_true", help="Create zip archive of results")
    return parser.parse_args()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def main():
    """Run the selected experiments end-to-end and write a master summary.

    Each experiment writes its own artifacts under a dedicated subdirectory
    of --output; the per-experiment summary dicts are aggregated into
    master_summary.json and a PBI table is logged at the end.
    """
    args = parse_args()
    model = args.model
    out_root = args.output
    os.makedirs(out_root, exist_ok=True)

    # Experiment selection: "all" is a sentinel member of the set.
    wanted = set(args.experiments.split(",")) if args.experiments != "all" else {"all"}

    logger.info("=" * 70)
    logger.info("LITM BENCHMARK SUITE v4")
    logger.info(f"Model: {model} | Output: {out_root}")
    logger.info("=" * 70)

    all_results = {}
    t0 = time.time()

    def should_run(name):
        # True when the user asked for everything or named this experiment.
        return "all" in wanted or name in wanted

    # NOTE(review): experiments 2-7 hard-code num_examples=30 while --n-examples
    # (default 50) is applied only to the two KV-retrieval runs — presumably to
    # bound runtime on the slower long-context experiments; confirm intent.
    if should_run("kv100"):
        logger.info("\n--- EXP 1A: KV Retrieval (100 keys) ---")
        all_results["kv_100"] = run_kv_retrieval(
            model_name=model,
            num_keys=args.n_keys_100,
            num_examples=args.n_examples,
            out_dir=os.path.join(out_root, "exp1a_kv100"),
            prefix="kv100",
        )

    if should_run("kv200"):
        logger.info("\n--- EXP 1B: KV Retrieval (200 keys) ---")
        all_results["kv_200"] = run_kv_retrieval(
            model_name=model,
            num_keys=args.n_keys_200,
            num_examples=args.n_examples,
            out_dir=os.path.join(out_root, "exp1b_kv200"),
            prefix="kv200",
        )

    if should_run("needle"):
        logger.info("\n--- EXP 2: Needle in Haystack ---")
        all_results["needle"] = run_needle_in_haystack(
            model_name=model,
            num_sentences=args.needle_sentences,
            num_examples=30,
            out_dir=os.path.join(out_root, "exp2_needle"),
        )

    if should_run("multi"):
        logger.info("\n--- EXP 3: Multi-Needle ---")
        all_results["multi"] = run_multi_needle(
            model_name=model,
            num_sentences=args.multi_sentences,
            num_examples=30,
            out_dir=os.path.join(out_root, "exp3_multi"),
        )

    if should_run("reason"):
        logger.info("\n--- EXP 4: Fact-Dependent Reasoning ---")
        all_results["reason"] = run_fact_reasoning(
            model_name=model,
            num_sentences=args.reason_sentences,
            num_examples=30,
            out_dir=os.path.join(out_root, "exp4_reason"),
        )

    if should_run("semantic"):
        logger.info("\n--- EXP 5: Semantic Similarity Distractors ---")
        all_results["semantic"] = run_semantic_distractors(
            model_name=model,
            num_facts=args.semantic_facts,
            num_examples=30,
            out_dir=os.path.join(out_root, "exp5_semantic"),
        )

    if should_run("narrative"):
        logger.info("\n--- EXP 6: Temporal Narrative ---")
        all_results["narrative"] = run_temporal_narrative(
            model_name=model,
            num_events=args.narrative_events,
            num_examples=30,
            out_dir=os.path.join(out_root, "exp6_narrative"),
        )

    if should_run("conversation"):
        logger.info("\n--- EXP 7: Conversation Memory ---")
        all_results["conversation"] = run_conversation_memory(
            model_name=model,
            num_turns=args.convo_turns,
            num_examples=30,
            out_dir=os.path.join(out_root, "exp7_conversation"),
        )

    elapsed = (time.time() - t0) / 3600
    logger.info(f"\n{'='*70}")
    logger.info(f"COMPLETE. Total time: {elapsed:.2f} hours")
    logger.info(f"Results: {out_root}")
    logger.info(f"{'='*70}")

    save_json(os.path.join(out_root, "master_summary.json"), all_results)

    # Print PBI table
    logger.info("\n--- Position Bias Index (PBI) Summary ---")
    for k, v in all_results.items():
        # Only experiment summaries carry a "pbi" key; skip anything else.
        if isinstance(v, dict) and "pbi" in v:
            logger.info(f" {k:20s} PBI = {v['pbi']:+.3f}")

    if args.zip:
        # Archive is written next to (not inside) the results directory.
        zip_path = os.path.join(os.path.dirname(out_root), "litm_results_all")
        shutil.make_archive(zip_path, "zip", out_root)
        logger.info(f"Zipped: {zip_path}.zip")


if __name__ == "__main__":
    main()
|
src/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LITM Benchmark Suite v4 — Core Library"""
|
| 2 |
+
__version__ = "4.0.0"
|
| 3 |
+
__author__ = "abhshkp"
|
src/generator.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Text generation wrapper with chat-template support."""
|
| 2 |
+
import logging
|
| 3 |
+
import torch
|
| 4 |
+
from typing import List, Dict, Any
|
| 5 |
+
from .model_loader import load_model
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def generate_text(
    messages: List[Dict[str, str]],
    model_name: str,
    max_new_tokens: int = 80,
    load_in_4bit: bool = True,
) -> str:
    """Greedily decode a reply for a chat-formatted message list.

    The model/tokenizer pair comes from the cached loader; the chat template
    is applied with a generation prompt, and only the newly generated tokens
    are decoded (the prompt prefix is sliced off).
    """
    model, tokenizer = load_model(model_name, load_in_4bit=load_in_4bit)

    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        return_tensors="pt",
        add_generation_prompt=True,
        return_dict=True,
    )

    # Move every tensor to wherever the (possibly sharded) model lives.
    device = next(model.parameters()).device
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        generated = model.generate(
            **batch,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding: deterministic answers for scoring
            pad_token_id=tokenizer.pad_token_id,
        )

    prompt_len = batch["input_ids"].shape[1]
    new_tokens = generated[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
src/metrics.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Metrics and scoring utilities."""
|
| 2 |
+
import re
|
| 3 |
+
import statistics
|
| 4 |
+
from typing import List, Dict, Any
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def exact_match_score(prediction: str, target: str) -> float:
    """Case-insensitive containment score: 1.0 if *target* occurs in *prediction*.

    Despite the name this is substring containment, not strict equality —
    model answers often wrap the target in extra words ("The answer is Paris.").
    """
    hit = target.lower() in prediction.lower()
    return float(hit)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def numeric_match(prediction: str, target: float, tolerance: float = 0.5) -> float:
    """Extract the first number in *prediction* and compare it to *target*.

    Commas are stripped first so "1,234" parses as 1234.  Returns 1.0 when
    the absolute difference is below *tolerance*, else 0.0 (also 0.0 when no
    number is present).
    """
    cleaned = prediction.replace(",", "")
    # Sign-aware fix: the previous pattern dropped a leading '-', so "-12"
    # was scored as 12.  Anchoring on a digit also means a stray '.' never
    # produces an unparseable token.
    nums = re.findall(r"-?\d+(?:\.\d+)?", cleaned)
    if not nums:
        return 0.0
    return 1.0 if abs(float(nums[0]) - target) < tolerance else 0.0
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def compute_accuracy(predictions: List[Dict[str, Any]], key: str = "correct") -> float:
    """Average of the *key* field across prediction records (0.0 for empty input)."""
    if not predictions:
        return 0.0
    return statistics.mean(rec[key] for rec in predictions)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def position_bias_index(positions: List[float], accuracies: List[float]) -> float:
    """Position Bias Index: mean of the first/last accuracies minus the middle one.

    PBI = (acc_first + acc_last) / 2 - acc_middle.  Positive values indicate
    a U-shaped ("lost in the middle") curve; fewer than three positions
    yields 0.0 since no middle exists.
    """
    n = len(positions)
    if n < 3:
        return 0.0
    edge_mean = (accuracies[0] + accuracies[-1]) / 2.0
    middle = accuracies[n // 2]
    return edge_mean - middle
|
src/model_loader.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model loading with 4-bit quantization for T4/GPU inference."""
|
| 2 |
+
import logging
|
| 3 |
+
import torch
|
| 4 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
_model_cache = {}
|
| 9 |
+
_tok_cache = {}
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_model(model_name: str, load_in_4bit: bool = True, device_map: str = "auto"):
    """Load a causal LM and tokenizer, optionally NF4-quantized; results are cached.

    The cache key covers (model_name, load_in_4bit, device_map), so repeated
    calls from the experiment loops reuse the same in-memory model.
    """
    key = f"{model_name}:{load_in_4bit}:{device_map}"
    if key in _model_cache:
        return _model_cache[key], _tok_cache[key]

    logger.info(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Some chat models ship without a pad token; fall back to EOS so
        # generate() can pad.
        tokenizer.pad_token = tokenizer.eos_token

    load_kwargs = dict(
        device_map=device_map,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    if load_in_4bit:
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

    model.eval()
    logger.info(f"Model loaded on {next(model.parameters()).device}")

    _model_cache[key] = model
    _tok_cache[key] = tokenizer
    return model, tokenizer
|
src/plotting.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Plotting utilities for position-bias curves."""
|
| 2 |
+
import logging
|
| 3 |
+
import matplotlib
|
| 4 |
+
matplotlib.use("Agg")
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def plot_curve(
    x_values,
    y_values,
    title: str,
    save_path: str,
    xlabel: str = "Position (0=start, 1=end)",
    ylabel: str = "Accuracy",
    ylim: tuple = (-0.05, 1.05),
    color: str = "#E63946",
):
    """Accuracy-vs-position line plot, written to *save_path* at 200 dpi."""
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(x_values, y_values, marker="o", linewidth=2.5, markersize=10, color=color)
    ax.set_xlabel(xlabel, fontsize=13)
    ax.set_ylabel(ylabel, fontsize=13)
    ax.set_title(title, fontsize=13)
    ax.set_ylim(ylim)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path, dpi=200)
    plt.close(fig)
    logger.info(f"Plot saved: {save_path}")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def plot_bar(categories, values, title: str, save_path: str, ylabel: str = "Accuracy", ylim=(0, 1.05), colors=None):
    """Bar chart for categorical position comparisons (e.g. start/middle/end)."""
    # Default palette highlights the middle bar (red) against the edges (blue).
    bar_colors = ["#2E86AB", "#E63946", "#2E86AB"] if colors is None else colors
    fig, ax = plt.subplots(figsize=(6, 5))
    ax.bar(categories, values, color=bar_colors, edgecolor="black", linewidth=1.2)
    ax.set_ylabel(ylabel, fontsize=13)
    ax.set_title(title, fontsize=13)
    ax.set_ylim(ylim)
    ax.grid(True, alpha=0.3, axis="y")
    fig.tight_layout()
    fig.savefig(save_path, dpi=200)
    plt.close(fig)
    logger.info(f"Bar plot saved: {save_path}")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def plot_multi_curves(curves, labels, title, save_path, xlabel="Position", ylabel="Accuracy"):
    """Overlay several (x, y) curves for cross-experiment comparison.

    *curves* is a dict holding parallel lists under "x" and "y"; *labels*
    names each curve in the legend.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    palette = plt.get_cmap("tab10")
    for idx, (xs, ys, name) in enumerate(zip(curves["x"], curves["y"], labels)):
        ax.plot(xs, ys, marker="o", linewidth=2.0, markersize=8, label=name, color=palette(idx))
    ax.set_xlabel(xlabel, fontsize=13)
    ax.set_ylabel(ylabel, fontsize=13)
    ax.set_title(title, fontsize=13)
    ax.set_ylim(-0.05, 1.05)
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path, dpi=200)
    plt.close(fig)
    logger.info(f"Multi-curve plot saved: {save_path}")
|
src/utils.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Common utilities."""
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import logging
|
| 5 |
+
from typing import List, Dict, Any
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def ensure_dir(path: str):
    """Create *path* (including parents); a pre-existing directory is fine."""
    os.makedirs(path, exist_ok=True)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def save_jsonl(path: str, records: List[Dict[str, Any]]):
    """Write *records* to *path*, one JSON object per line (overwrites)."""
    with open(path, "w") as fh:
        fh.writelines(json.dumps(rec) + "\n" for rec in records)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def load_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read a JSONL file into a list of records (one parsed object per line)."""
    with open(path) as fh:
        return [json.loads(line) for line in fh]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def save_json(path: str, data: Any):
    """Serialize *data* to *path* as indented (pretty-printed) JSON."""
    with open(path, "w") as fh:
        fh.write(json.dumps(data, indent=2))
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def load_json(path: str) -> Any:
    """Parse and return the JSON document stored at *path*."""
    with open(path) as fh:
        return json.load(fh)
|