""" Experiment 3: Multi-Needle Retrieval Tests ability to retrieve ALL of multiple needles placed at start, middle, and end. """ import logging import os import random import time from typing import Dict, Any from tqdm import tqdm from src.generator import generate_text from src.metrics import exact_match_score, compute_accuracy from src.plotting import plot_bar from src.utils import ensure_dir, save_json logger = logging.getLogger(__name__) from .needle_in_haystack import FILLERS def _make_haystack(n: int) -> str: return " ".join(random.choice(FILLERS) + f" [{i+1}]." for i in range(n)) def run_multi_needle( model_name: str, num_sentences: int, num_examples: int, out_dir: str, ) -> Dict[str, Any]: """Run multi-needle experiment.""" ensure_dir(out_dir) start = time.time() start_ok, mid_ok, end_ok = [], [], [] for i in tqdm(range(num_examples), desc="Multi-needle"): filler = _make_haystack(num_sentences) sents = [s.strip() + "." for s in filler.split(".") if s.strip()] n = len(sents) ca, cb, cc = f"ALPHA-{i:03d}", f"BETA-{i:03d}", f"GAMMA-{i:03d}" sents.insert(0, f"The first secret code is {ca}.") sents.insert(n // 2, f"The second secret code is {cb}.") sents.append(f"The third secret code is {cc}.") prompt = ( f"Read the text and list ALL three secret codes in order.\n\n" f"{' '.join(sents)}\n\nCodes:" ) ans = generate_text( [{"role": "user", "content": prompt}], model_name=model_name, max_new_tokens=60, ) start_ok.append(exact_match_score(ans, ca)) mid_ok.append(exact_match_score(ans, cb)) end_ok.append(exact_match_score(ans, cc)) summary = { "experiment": "multi_needle", "num_sentences": num_sentences, "num_examples": num_examples, "start": compute_accuracy([{"correct": c} for c in start_ok]), "middle": compute_accuracy([{"correct": c} for c in mid_ok]), "end": compute_accuracy([{"correct": c} for c in end_ok]), "time_minutes": (time.time() - start) / 60, } logger.info( f"[MULTI] Start={summary['start']:.3f} Mid={summary['middle']:.3f} End={summary['end']:.3f}" ) save_json(os.path.join(out_dir, "multi_summary.json"), summary) plot_bar( ["Start", "Middle", "End"], [summary["start"], summary["middle"], summary["end"]], f"Exp 3: Multi-Needle (n={num_examples})", os.path.join(out_dir, "multi_bar.png"), ) return summary