abhshkp commited on
Commit
6da1df8
·
verified ·
1 Parent(s): 72a171f

Upload experiments/needle_in_haystack.py

Browse files
Files changed (1) hide show
  1. experiments/needle_in_haystack.py +93 -67
experiments/needle_in_haystack.py CHANGED
@@ -1,8 +1,9 @@
1
  """
2
  Experiment 2: Needle in Haystack (text)
3
- Tests retrieval of a secret code hidden at varying depths in filler text.
4
- FIXED: Increased default context to 1500 sentences to stress 1.5B model attention.
5
- Also adds decoy codes to prevent trivial keyword-only retrieval.
 
6
  """
7
  import logging
8
  import os
@@ -19,66 +20,81 @@ from src.utils import ensure_dir, save_jsonl, save_json
19
 
20
  logger = logging.getLogger(__name__)
21
 
 
22
  FILLERS = [
23
  "The history of pottery spans thousands of years.",
24
  "Marine biologists study coral reef ecosystems.",
25
- "Railway engineering requires precise curvature.",
26
- "The periodic table arranges elements by number.",
27
- "Clouds are classified as cumulus and stratus.",
28
- "Beekeeping traditions differ between continents.",
29
- "The Great Wall was built over many dynasties.",
30
- "Thermodynamics governs heat transfer.",
31
- "Impressionist painters captured fleeting light.",
32
- "Volcanic activity is tracked with seismographs.",
33
- "The Dewey Decimal System organizes libraries.",
34
- "Irrigation evolved from canals to drip systems.",
35
- "Neural networks are inspired by biological brains.",
36
- "Light speed is 299,792,458 meters per second.",
37
- "Classical composition follows harmonic rules.",
38
- "Urban planning addresses zoning and transport.",
39
- "Photosynthesis converts CO2 into glucose.",
40
- "The Fibonacci sequence appears in nature.
41
- "GPS uses triangulation from satellites.",
42
- "Cryptography secures digital communication.",
43
- "Aerodynamics explains lift on aircraft wings.",
44
- "Meteorologists track pressure systems globally.",
45
- "The Rosetta Stone enabled hieroglyph translation.",
46
- "Bioluminescence occurs in deep ocean species.",
47
- "Microprocessors revolutionized personal computing.",
48
- "Tectonic plates shift gradually over millennia.",
49
- "The printing press spread knowledge across Europe.",
50
- "Quantum entanglement defies classical intuition.",
51
- "Archaeologists excavate buried ancient settlements.",
52
- "Photosensors detect minute light variations.",
53
  ]
54
 
55
- # Decoy sentences with codes that use similar patterns
56
- DECOY_TEMPLATES = [
57
- "The transaction was logged with code TX-{code}.",
58
- "The batch identifier is BC-{code}.",
59
- "Session recorded under SC-{code}.",
60
- "Access granted via AC-{code}.",
61
- "Error logged as EC-{code}.",
62
- "Debug trace shows DC-{code}.",
63
- "Network packet tagged NC-{code}.",
64
- "User authenticated with UC-{code}.",
65
- "System heartbeat code SY-{code}.",
66
- "Database query ID DB-{code}.",
67
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
 
70
- def _make_haystack(n: int, num_decoys: int = 15) -> str:
71
- """Generate n sentences of filler text with decoy codes scattered throughout."""
72
  sents = []
73
- for i in range(n):
74
- if random.random() < (num_decoys / n) and num_decoys > 0:
75
- # Insert a decoy sentence instead of filler
76
- template = random.choice(DECOY_TEMPLATES)
77
- code = f"{random.randint(1000, 9999)}"
78
- sents.append(template.format(code=code))
79
- num_decoys -= 1
80
- else:
81
- sents.append(random.choice(FILLERS) + f" [{i+1}].")
82
  random.shuffle(sents)
83
  return " ".join(sents)
84
 
@@ -98,7 +114,7 @@ def run_needle_in_haystack(
98
  out_dir: str,
99
  depths: List[float] = None,
100
  ) -> Dict[str, Any]:
101
- """Run needle-in-haystack experiment with decoy codes."""
102
  ensure_dir(out_dir)
103
 
104
  if depths is None:
@@ -111,27 +127,37 @@ def run_needle_in_haystack(
111
  logger.info(f"[NEEDLE] Depth {depth:.1%}")
112
  preds = []
113
  for i in tqdm(range(num_examples), desc=f"Needle {depth:.1%}", leave=False):
114
- filler = _make_haystack(num_sentences, num_decoys=15)
115
- code = f"{random.randint(10000, 99999)}"
116
- # Target uses a different prefix than decoys to make it distinguishable
117
- # but the model must find it among many codes
118
- needle = f"The classified identifier is CL-{code}."
 
 
 
 
 
 
119
  text = _insert_needle(filler, needle, depth)
 
 
120
  prompt = (
121
- f"Read the text carefully. Multiple codes appear, but only one is "
122
- f"the classified identifier. Find it.\n\n{text}\n\n"
123
- f"What is the classified identifier? Answer with only the full code (including CL- prefix)."
 
124
  )
 
125
  ans = generate_text(
126
  [{"role": "user", "content": prompt}],
127
  model_name=model_name,
128
- max_new_tokens=20,
129
  )
130
- correct = exact_match_score(ans, f"CL-{code}")
131
  preds.append({
132
  "model_answer": ans,
133
  "correct": correct,
134
- "secret": f"CL-{code}",
135
  "depth": depth,
136
  })
137
 
 
1
  """
2
  Experiment 2: Needle in Haystack (text)
3
+ Tests retrieval of a fact hidden at varying depths in filler text.
4
+ FIXED: 2000-sentence haystacks + entity-overlap distractors to prevent keyword-only
5
+ retrieval. The target entities (person, item, location) each appear in multiple
6
+ sentences, forcing the model to attend to the right combination.
7
  """
8
  import logging
9
  import os
 
20
 
21
  logger = logging.getLogger(__name__)
22
 
23
+ # Generic filler sentences
24
  FILLERS = [
25
  "The history of pottery spans thousands of years.",
26
  "Marine biologists study coral reef ecosystems.",
27
+ "Railway engineering requires precise curvature calculations.",
28
+ "The periodic table arranges elements by atomic number.",
29
+ "Clouds are classified into cumulus and stratus types.",
30
+ "Beekeeping traditions differ significantly between continents.",
31
+ "The Great Wall was constructed over many successive dynasties.",
32
+ "Thermodynamics governs the principles of heat transfer.",
33
+ "Impressionist painters captured fleeting effects of light.",
34
+ "Volcanic activity is closely tracked with seismographs.",
35
+ "The Dewey Decimal System organizes library collections worldwide.",
36
+ "Irrigation technology evolved from canals to drip systems.",
37
+ "Neural networks are directly inspired by biological brains.",
38
+ "Light speed in vacuum is 299,792,458 meters per second.",
39
+ "Classical composition generally follows established harmonic rules.",
40
+ "Urban planning must address zoning and public transport.",
41
+ "Photosynthesis converts carbon dioxide into glucose and oxygen.",
42
+ "The Fibonacci sequence appears frequently throughout nature.",
43
+ "GPS navigation uses triangulation from orbiting satellites.",
44
+ "Cryptography secures modern digital communications against eavesdropping.",
 
 
 
 
 
 
 
 
 
 
45
  ]
46
 
47
+ # Entities that appear in MULTIPLE sentences no single entity is unique
48
+ NAMES = ["Alice", "Bob", "Carol", "David", "Eve", "Frank", "Grace", "Heidi"]
49
+ ITEMS = ["bicycle", "laptop", "watch", "camera", "guitar", "sneakers", "backpack", "headphones"]
50
+ PLACES = ["downtown shop", "uptown store", "westside mall", "eastside market", "riverside plaza"]
51
+
52
+
53
+ def _make_entity_distractor(person: str, item: str, place: str, used: set) -> str:
54
+ """Create a distractor sentence sharing 1-2 entities with the target but not all 3."""
55
+ templates = [
56
+ "{person} visited the {place} last Tuesday to browse items.",
57
+ "The {place} sells various products including {item}s and accessories.",
58
+ "{person} enjoys using their {item} during weekend activities.",
59
+ "A customer purchased a {item} from the {place} earlier this month.",
60
+ "{person} recommended the {place} to friends and family members.",
61
+ "The {place} had a promotional sale on {item}s last holiday season.",
62
+ "{person} previously owned a different {item} before upgrading.",
63
+ "Shoppers at the {place} often look for quality {item}s.",
64
+ ]
65
+ # Pick a template and substitute with random entities (may overlap)
66
+ tmpl = random.choice(templates)
67
+ p = random.choice(NAMES)
68
+ it = random.choice(ITEMS)
69
+ pl = random.choice(PLACES)
70
+ sent = tmpl.format(person=p, item=it, place=pl)
71
+ # Ensure it shares at least one entity with the target (person, item, place)
72
+ # but not all three (otherwise it's a duplicate target)
73
+ if p == person and it == item and pl == place:
74
+ # Swap one entity to avoid being identical to target
75
+ swap = random.choice(["person", "item", "place"])
76
+ if swap == "person":
77
+ p = random.choice([n for n in NAMES if n != person])
78
+ elif swap == "item":
79
+ it = random.choice([i for i in ITEMS if i != item])
80
+ else:
81
+ pl = random.choice([pl for pl in PLACES if pl != place])
82
+ sent = tmpl.format(person=p, item=it, place=pl)
83
+ return sent
84
 
85
 
86
+ def _make_haystack(n: int, target_person: str, target_item: str, target_place: str, num_distractors: int = 40) -> str:
87
+ """Generate n sentences with entity-overlap distractors scattered throughout."""
88
  sents = []
89
+
90
+ # Add distractor sentences that share entities
91
+ for _ in range(num_distractors):
92
+ sents.append(_make_entity_distractor(target_person, target_item, target_place, set()))
93
+
94
+ # Fill remaining with generic fillers
95
+ while len(sents) < n:
96
+ sents.append(random.choice(FILLERS))
97
+
98
  random.shuffle(sents)
99
  return " ".join(sents)
100
 
 
114
  out_dir: str,
115
  depths: List[float] = None,
116
  ) -> Dict[str, Any]:
117
+ """Run needle-in-haystack with entity-overlap distractors."""
118
  ensure_dir(out_dir)
119
 
120
  if depths is None:
 
127
  logger.info(f"[NEEDLE] Depth {depth:.1%}")
128
  preds = []
129
  for i in tqdm(range(num_examples), desc=f"Needle {depth:.1%}", leave=False):
130
+ # Choose target entities
131
+ person = random.choice(NAMES)
132
+ item = random.choice(ITEMS)
133
+ place = random.choice(PLACES)
134
+ price = random.randint(100, 999)
135
+
136
+ # Build haystack with entity-overlap distractors
137
+ filler = _make_haystack(num_sentences, person, item, place, num_distractors=40)
138
+
139
+ # Target sentence (the needle)
140
+ needle = f"{person} purchased a {item} from the {place} for ${price}."
141
  text = _insert_needle(filler, needle, depth)
142
+
143
+ # Question forces the model to find the RIGHT combination, not just any mention
144
  prompt = (
145
+ f"Read the passage carefully. {person} is mentioned several times, "
146
+ f"and the {place} is mentioned several times, and {item}s are mentioned several times. "
147
+ f"Find the specific sentence that says how much {person} paid for a {item} at the {place}. "
148
+ f"Answer with only the dollar amount (no $ sign, no words)."
149
  )
150
+
151
  ans = generate_text(
152
  [{"role": "user", "content": prompt}],
153
  model_name=model_name,
154
+ max_new_tokens=10,
155
  )
156
+ correct = exact_match_score(ans, str(price))
157
  preds.append({
158
  "model_answer": ans,
159
  "correct": correct,
160
+ "expected": price,
161
  "depth": depth,
162
  })
163