File size: 8,093 Bytes
e1e1ce9
 
6da1df8
 
 
 
e1e1ce9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6da1df8
e1e1ce9
 
 
6da1df8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72a171f
 
6da1df8
 
 
 
e1e1ce9
 
6da1df8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72a171f
6da1df8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72a171f
 
e1e1ce9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6da1df8
e1e1ce9
 
 
 
 
 
 
 
 
 
 
 
6da1df8
 
 
 
 
 
 
 
 
 
 
e1e1ce9
6da1df8
 
e1e1ce9
6da1df8
 
 
 
e1e1ce9
6da1df8
e1e1ce9
 
 
6da1df8
e1e1ce9
6da1df8
e1e1ce9
 
 
6da1df8
e1e1ce9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""
Experiment 2: Needle in Haystack (text)
Tests retrieval of a fact hidden at varying depths in filler text.
FIXED: 2000-sentence haystacks + entity-overlap distractors to prevent keyword-only
retrieval. The target entities (person, item, location) each appear in multiple
sentences, forcing the model to attend to the right combination.
"""
import logging
import os
import random
import time
from typing import List, Dict, Any

from tqdm import tqdm

from src.generator import generate_text
from src.metrics import exact_match_score, compute_accuracy, position_bias_index
from src.plotting import plot_curve
from src.utils import ensure_dir, save_jsonl, save_json

logger = logging.getLogger(__name__)

# Generic filler sentences
FILLERS = [
    "The history of pottery spans thousands of years.",
    "Marine biologists study coral reef ecosystems.",
    "Railway engineering requires precise curvature calculations.",
    "The periodic table arranges elements by atomic number.",
    "Clouds are classified into cumulus and stratus types.",
    "Beekeeping traditions differ significantly between continents.",
    "The Great Wall was constructed over many successive dynasties.",
    "Thermodynamics governs the principles of heat transfer.",
    "Impressionist painters captured fleeting effects of light.",
    "Volcanic activity is closely tracked with seismographs.",
    "The Dewey Decimal System organizes library collections worldwide.",
    "Irrigation technology evolved from canals to drip systems.",
    "Neural networks are directly inspired by biological brains.",
    "Light speed in vacuum is 299,792,458 meters per second.",
    "Classical composition generally follows established harmonic rules.",
    "Urban planning must address zoning and public transport.",
    "Photosynthesis converts carbon dioxide into glucose and oxygen.",
    "The Fibonacci sequence appears frequently throughout nature.",
    "GPS navigation uses triangulation from orbiting satellites.",
    "Cryptography secures modern digital communications against eavesdropping.",
]

# Entities that appear in MULTIPLE sentences — no single entity is unique
NAMES = ["Alice", "Bob", "Carol", "David", "Eve", "Frank", "Grace", "Heidi"]
ITEMS = ["bicycle", "laptop", "watch", "camera", "guitar", "sneakers", "backpack", "headphones"]
PLACES = ["downtown shop", "uptown store", "westside mall", "eastside market", "riverside plaza"]


def _make_entity_distractor(person: str, item: str, place: str, used: set) -> str:
    """Create a distractor sentence sharing 1-2 entities with the target but not all 3."""
    templates = [
        "{person} visited the {place} last Tuesday to browse items.",
        "The {place} sells various products including {item}s and accessories.",
        "{person} enjoys using their {item} during weekend activities.",
        "A customer purchased a {item} from the {place} earlier this month.",
        "{person} recommended the {place} to friends and family members.",
        "The {place} had a promotional sale on {item}s last holiday season.",
        "{person} previously owned a different {item} before upgrading.",
        "Shoppers at the {place} often look for quality {item}s.",
    ]
    # Pick a template and substitute with random entities (may overlap)
    tmpl = random.choice(templates)
    p = random.choice(NAMES)
    it = random.choice(ITEMS)
    pl = random.choice(PLACES)
    sent = tmpl.format(person=p, item=it, place=pl)
    # Ensure it shares at least one entity with the target (person, item, place)
    # but not all three (otherwise it's a duplicate target)
    if p == person and it == item and pl == place:
        # Swap one entity to avoid being identical to target
        swap = random.choice(["person", "item", "place"])
        if swap == "person":
            p = random.choice([n for n in NAMES if n != person])
        elif swap == "item":
            it = random.choice([i for i in ITEMS if i != item])
        else:
            pl = random.choice([pl for pl in PLACES if pl != place])
        sent = tmpl.format(person=p, item=it, place=pl)
    return sent


def _make_haystack(n: int, target_person: str, target_item: str, target_place: str, num_distractors: int = 40) -> str:
    """Generate exactly n sentences with entity-overlap distractors scattered throughout.

    Args:
        n: Total number of sentences in the haystack.
        target_person: Person entity the distractors overlap with.
        target_item: Item entity the distractors overlap with.
        target_place: Place entity the distractors overlap with.
        num_distractors: Requested distractor count; capped at n so the
            haystack never exceeds n sentences (previously, n < num_distractors
            produced a haystack longer than requested).

    Returns:
        Space-joined haystack text of exactly n sentences, shuffled so the
        distractors are scattered rather than clustered.
    """
    # Cap so small haystacks are not overflowed by distractors.
    k = min(max(n, 0), num_distractors)
    sents = [
        _make_entity_distractor(target_person, target_item, target_place, set())
        for _ in range(k)
    ]
    # Fill the remainder with generic, entity-free filler sentences.
    sents.extend(random.choice(FILLERS) for _ in range(n - len(sents)))
    random.shuffle(sents)
    return " ".join(sents)


def _insert_needle(text: str, needle: str, ratio: float) -> str:
    """Insert needle at specified depth ratio."""
    sents = [s.strip() + "." for s in text.split(".") if s.strip()]
    idx = int(ratio * len(sents))
    sents.insert(idx, needle)
    return " ".join(sents)


def run_needle_in_haystack(
    model_name: str,
    num_sentences: int,
    num_examples: int,
    out_dir: str,
    depths: List[float] = None,
) -> Dict[str, Any]:
    """Run needle-in-haystack with entity-overlap distractors.

    For each depth, builds num_examples haystacks, hides a unique
    (person, item, place, price) fact at that depth, and asks the model to
    recover the dollar amount.

    Args:
        model_name: Model identifier passed through to generate_text.
        num_sentences: Number of sentences per haystack.
        num_examples: Examples evaluated per depth.
        out_dir: Directory for per-depth JSONL dumps, summary JSON, and plot.
        depths: Fractional insertion depths in [0, 1]; defaults to 9 evenly
            spaced points from 0.0 to 1.0.

    Returns:
        Summary dict with per-depth accuracy, position-bias index (PBI),
        and total runtime in minutes.
    """
    ensure_dir(out_dir)

    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[NEEDLE] Depth {depth:.1%}")
        preds = []
        for _ in tqdm(range(num_examples), desc=f"Needle {depth:.1%}", leave=False):
            # Choose target entities for this example's needle
            person = random.choice(NAMES)
            item = random.choice(ITEMS)
            place = random.choice(PLACES)
            price = random.randint(100, 999)

            # Build haystack with entity-overlap distractors
            filler = _make_haystack(num_sentences, person, item, place, num_distractors=40)

            # Target sentence (the needle)
            needle = f"{person} purchased a {item} from the {place} for ${price}."
            text = _insert_needle(filler, needle, depth)

            # Question forces the model to find the RIGHT combination, not just any mention.
            # BUG FIX: the passage itself was never included in the prompt — the
            # model previously received only the question and had nothing to
            # retrieve from. The haystack text now precedes the question.
            prompt = (
                f"{text}\n\n"
                f"Read the passage carefully. {person} is mentioned several times, "
                f"and the {place} is mentioned several times, and {item}s are mentioned several times. "
                f"Find the specific sentence that says how much {person} paid for a {item} at the {place}. "
                f"Answer with only the dollar amount (no $ sign, no words)."
            )

            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=10,
            )
            correct = exact_match_score(ans, str(price))
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "expected": price,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"needle_depth_{depth}.jsonl"), preds)
        acc = compute_accuracy(preds)
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[NEEDLE] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": "needle_in_haystack",
        "num_sentences": num_sentences,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "pbi": position_bias_index(depths, [results[d]["accuracy"] for d in depths]),
        "time_minutes": (time.time() - start) / 60,
    }

    save_json(os.path.join(out_dir, "needle_summary.json"), summary)
    plot_curve(
        depths,
        [results[d]["accuracy"] for d in depths],
        f"Exp 2: Needle in Haystack ({num_sentences} sentences)",
        os.path.join(out_dir, "needle_curve.png"),
        xlabel="Depth in Document (0=start, 1=end)",
    )

    logger.info(f"[NEEDLE] Time={(time.time()-start)/60:.1f} min")
    return summary