"""
Experiment 3: Multi-Needle Retrieval
Tests whether the model retrieves all three needles placed at the start, middle, and end of the context.
"""
import logging
import os
import random
import time
from typing import Dict, Any

from tqdm import tqdm

from src.generator import generate_text
from src.metrics import exact_match_score, compute_accuracy
from src.plotting import plot_bar
from src.utils import ensure_dir, save_json

from .needle_in_haystack import FILLERS

logger = logging.getLogger(__name__)


def _make_haystack(n: int) -> str:
    """Return n random filler sentences, each tagged with its 1-based index."""
    return " ".join(random.choice(FILLERS) + f" [{i+1}]." for i in range(n))


def run_multi_needle(
    model_name: str,
    num_sentences: int,
    num_examples: int,
    out_dir: str,
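) -> Dict[str, Any]:
    """Run the multi-needle experiment and return per-position accuracy.

    Each example hides three codes at the start, middle, and end of a
    haystack of ``num_sentences`` filler sentences. The summary JSON and a
    bar chart are written to ``out_dir``.
    """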
    ensure_dir(out_dir)

    start = time.time()
    start_ok, mid_ok, end_ok = [], [], []

    for i in tqdm(range(num_examples), desc="Multi-needle"):
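        filler = _make_haystack(num_sentences)
        # Re-split into sentences so needles can be inserted at sentence boundaries.
        sents = [s.strip() + "." for s in filler.split(".") if s.strip()]
        # Codes are unique per example.
        ca, cb, cc = f"ALPHA-{i:03d}", f"BETA-{i:03d}", f"GAMMA-{i:03d}"

        # Insert the middle needle before the start needle so that it lands at
        # the midpoint of the filler sentences rather than one position early.
        sents.insert(len(sents) // 2, f"The second secret code is {cb}.")
        sents.insert(0, f"The first secret code is {ca}.")
        sents.append(f"The third secret code is {cc}.")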

        prompt = (
            "Read the text and list ALL three secret codes in order.\n\n"
            f"{' '.join(sents)}\n\nCodes:"
        )
        ans = generate_text(
            [{"role": "user", "content": prompt}],
            model_name=model_name,
            max_new_tokens=60,
        )
        # Score each needle independently against the same model answer.
        start_ok.append(exact_match_score(ans, ca))
        mid_ok.append(exact_match_score(ans, cb))
        end_ok.append(exact_match_score(ans, cc))

    summary = {
        "experiment": "multi_needle",
        "num_sentences": num_sentences,
        "num_examples": num_examples,
        "start": compute_accuracy([{"correct": c} for c in start_ok]),
        "middle": compute_accuracy([{"correct": c} for c in mid_ok]),
        "end": compute_accuracy([{"correct": c} for c in end_ok]),
        "time_minutes": (time.time() - start) / 60,
    }

    logger.info(
        "[MULTI] Start=%.3f Mid=%.3f End=%.3f",
        summary["start"],
        summary["middle"],
        summary["end"],
    )

    save_json(os.path.join(out_dir, "multi_summary.json"), summary)
    plot_bar(
        ["Start", "Middle", "End"],
        [summary["start"], summary["middle"], summary["end"]],
        f"Exp 3: Multi-Needle (n={num_examples})",
        os.path.join(out_dir, "multi_bar.png"),
    )

    return summary