litm-benchmark-suite-v4 / experiments /needle_in_haystack.py
abhshkp's picture
Upload experiments/needle_in_haystack.py
6da1df8 verified
"""
Experiment 2: Needle in Haystack (text)
Tests retrieval of a fact hidden at varying depths in filler text.
FIXED: 2000-sentence haystacks + entity-overlap distractors to prevent keyword-only
retrieval. The target entities (person, item, location) each appear in multiple
sentences, forcing the model to attend to the right combination.
"""
import logging
import os
import random
import re
import time
from typing import Any, Dict, List, Optional

from tqdm import tqdm

from src.generator import generate_text
from src.metrics import exact_match_score, compute_accuracy, position_bias_index
from src.plotting import plot_curve
from src.utils import ensure_dir, save_jsonl, save_json
logger = logging.getLogger(__name__)
# Topic-neutral filler sentences; _make_haystack samples from this list (with
# replacement) to pad the passage up to the requested length. Every sentence
# ends in a single period and contains no other period.
FILLERS = [
    "The history of pottery spans thousands of years.",
    "Marine biologists study coral reef ecosystems.",
    "Railway engineering requires precise curvature calculations.",
    "The periodic table arranges elements by atomic number.",
    "Clouds are classified into cumulus and stratus types.",
    "Beekeeping traditions differ significantly between continents.",
    "The Great Wall was constructed over many successive dynasties.",
    "Thermodynamics governs the principles of heat transfer.",
    "Impressionist painters captured fleeting effects of light.",
    "Volcanic activity is closely tracked with seismographs.",
    "The Dewey Decimal System organizes library collections worldwide.",
    "Irrigation technology evolved from canals to drip systems.",
    "Neural networks are directly inspired by biological brains.",
    "Light speed in vacuum is 299,792,458 meters per second.",
    "Classical composition generally follows established harmonic rules.",
    "Urban planning must address zoning and public transport.",
    "Photosynthesis converts carbon dioxide into glucose and oxygen.",
    "The Fibonacci sequence appears frequently throughout nature.",
    "GPS navigation uses triangulation from orbiting satellites.",
    "Cryptography secures modern digital communications against eavesdropping.",
]
# Entity pools shared by the needle sentence and the distractor sentences:
# both draw people, items, and places from these lists, so an entity name on
# its own is not meant to single out the needle.
NAMES = ["Alice", "Bob", "Carol", "David", "Eve", "Frank", "Grace", "Heidi"]
ITEMS = ["bicycle", "laptop", "watch", "camera", "guitar", "sneakers", "backpack", "headphones"]
PLACES = ["downtown shop", "uptown store", "westside mall", "eastside market", "riverside plaza"]
def _make_entity_distractor(person: str, item: str, place: str, used: set) -> str:
    """Create a distractor sentence sharing 1-2 entities with the target.

    A template is chosen at random and filled with entities drawn from NAMES,
    ITEMS and PLACES. Every template renders exactly two of the three entity
    kinds, so the result can never restate the full (person, item, place)
    triple of the needle. If neither rendered entity happens to match the
    target, one rendered slot is overwritten with the target's value so the
    sentence is guaranteed to overlap with the needle.

    Args:
        person: Target person named in the needle.
        item: Target item named in the needle.
        place: Target place named in the needle.
        used: Accepted for interface compatibility; it is neither read nor
            modified.

    Returns:
        A single sentence ending in a period.
    """
    templates = [
        "{person} visited the {place} last Tuesday to browse items.",
        "The {place} sells various products including {item}s and accessories.",
        "{person} enjoys using their {item} during weekend activities.",
        "A customer purchased a {item} from the {place} earlier this month.",
        "{person} recommended the {place} to friends and family members.",
        "The {place} had a promotional sale on {item}s last holiday season.",
        "{person} previously owned a different {item} before upgrading.",
        "Shoppers at the {place} often look for quality {item}s.",
    ]
    tmpl = random.choice(templates)
    target = {"person": person, "item": item, "place": place}
    values = {
        "person": random.choice(NAMES),
        "item": random.choice(ITEMS),
        "place": random.choice(PLACES),
    }
    # Only placeholders that actually occur in the template end up in the
    # sentence, so overlap has to be judged on those slots alone.
    rendered = [key for key in target if "{" + key + "}" in tmpl]
    if rendered and not any(values[key] == target[key] for key in rendered):
        # A purely random draw shared nothing with the needle; pin one visible
        # slot to the target so the sentence is a real distractor.
        pinned = random.choice(rendered)
        values[pinned] = target[pinned]
    return tmpl.format(**values)
def _make_haystack(n: int, target_person: str, target_item: str, target_place: str, num_distractors: int = 40) -> str:
    """Generate n sentences with entity-overlap distractors scattered throughout.

    Args:
        n: Total number of sentences in the haystack (the needle is not
            included; it is inserted separately).
        target_person: Person named in the needle.
        target_item: Item named in the needle.
        target_place: Place named in the needle.
        num_distractors: Requested number of distractor sentences. Capped at
            ``n`` so the haystack never exceeds the requested length.

    Returns:
        The shuffled sentences joined by single spaces; an empty string when
        ``n`` is zero or negative.
    """
    # Never emit more distractors than the haystack has room for.
    distractor_count = max(0, min(num_distractors, n))
    sents = [
        _make_entity_distractor(target_person, target_item, target_place, set())
        for _ in range(distractor_count)
    ]
    # Pad with generic fillers up to exactly n sentences.
    sents.extend(random.choice(FILLERS) for _ in range(n - len(sents)))
    random.shuffle(sents)
    return " ".join(sents)
def _insert_needle(text: str, needle: str, ratio: float) -> str:
"""Insert needle at specified depth ratio."""
sents = [s.strip() + "." for s in text.split(".") if s.strip()]
idx = int(ratio * len(sents))
sents.insert(idx, needle)
return " ".join(sents)
def _build_needle_prompt(text: str, person: str, item: str, place: str) -> str:
    """Return the user prompt: the passage followed by the retrieval question."""
    return (
        f"Passage:\n{text}\n\n"
        f"Read the passage carefully. {person} is mentioned several times, "
        f"and the {place} is mentioned several times, and {item}s are mentioned several times. "
        f"Find the specific sentence that says how much {person} paid for a {item} at the {place}. "
        f"Answer with only the dollar amount (no $ sign, no words)."
    )


def run_needle_in_haystack(
    model_name: str,
    num_sentences: int,
    num_examples: int,
    out_dir: str,
    depths: Optional[List[float]] = None,
    num_distractors: int = 40,
) -> Dict[str, Any]:
    """Run needle-in-haystack with entity-overlap distractors.

    For every depth, ``num_examples`` passages are generated. Each passage
    hides one sentence stating what a random person paid for a random item at
    a random place; the model is asked for that price and the reply is scored
    with ``exact_match_score`` against the price as a string.

    Side effects: writes ``needle_depth_<depth>.jsonl`` per depth, plus
    ``needle_summary.json`` and ``needle_curve.png``, into ``out_dir``.

    Args:
        model_name: Model identifier forwarded to ``generate_text``.
        num_sentences: Haystack length in sentences (needle excluded).
        num_examples: Number of passages evaluated per depth.
        out_dir: Output directory; created via ``ensure_dir``.
        depths: Relative needle positions (0 = start, 1 = end). Defaults to
            nine evenly spaced depths from 0.0 to 1.0.
        num_distractors: Distractor sentences requested per haystack.

    Returns:
        The summary dict that is also saved to ``needle_summary.json``, with
        keys ``experiment``, ``num_sentences``, ``num_examples``, ``depths``
        (accuracy per depth, keyed by ``str(depth)``), ``pbi`` and
        ``time_minutes``.
    """
    ensure_dir(out_dir)
    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
    results = {}
    start = time.time()
    for depth in depths:
        logger.info(f"[NEEDLE] Depth {depth:.1%}")
        preds = []
        for _ in tqdm(range(num_examples), desc=f"Needle {depth:.1%}", leave=False):
            # Choose target entities
            person = random.choice(NAMES)
            item = random.choice(ITEMS)
            place = random.choice(PLACES)
            price = random.randint(100, 999)
            # Build haystack with entity-overlap distractors
            filler = _make_haystack(
                num_sentences, person, item, place, num_distractors=num_distractors
            )
            # Target sentence (the needle)
            needle = f"{person} purchased a {item} from the {place} for ${price}."
            text = _insert_needle(filler, needle, depth)
            # The passage must be part of the prompt; without it the model is
            # asked about text it never receives.
            prompt = _build_needle_prompt(text, person, item, place)
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=10,
            )
            correct = exact_match_score(ans, str(price))
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "expected": price,
                "depth": depth,
            })
        save_jsonl(os.path.join(out_dir, f"needle_depth_{depth}.jsonl"), preds)
        acc = compute_accuracy(preds)
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[NEEDLE] Depth {depth:.1%}: acc={acc:.3f}")
    # Accuracy per depth, in the caller's depth order; reused for the summary,
    # the bias index and the plot.
    accuracies = [results[d]["accuracy"] for d in depths]
    summary = {
        "experiment": "needle_in_haystack",
        "num_sentences": num_sentences,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "pbi": position_bias_index(depths, accuracies),
        "time_minutes": (time.time() - start) / 60,
    }
    save_json(os.path.join(out_dir, "needle_summary.json"), summary)
    plot_curve(
        depths,
        accuracies,
        f"Exp 2: Needle in Haystack ({num_sentences} sentences)",
        os.path.join(out_dir, "needle_curve.png"),
        xlabel="Depth in Document (0=start, 1=end)",
    )
    logger.info(f"[NEEDLE] Time={(time.time()-start)/60:.1f} min")
    return summary