"""
Bloom's Level: Remember
Simple factual recall from long context.
"""

import logging
import os
import random
import time
from typing import Any, Dict, List, Optional, Tuple

from tqdm import tqdm

from src.generator import generate_text
from src.metrics import exact_match_score, compute_accuracy
from src.utils import ensure_dir, save_jsonl, save_json

logger = logging.getLogger(__name__)

FILLERS = [
    "The museum houses artifacts from the ancient world.",
    "Coral reefs support diverse marine ecosystems.",
    "Railway gauges vary between countries.",
    "The periodic table organizes elements systematically.",
    "Cloud formation depends on atmospheric pressure.",
]

# Depths evaluated when the caller passes none: 0%, 12.5%, ..., 100%.
_DEFAULT_DEPTHS = (0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0)


def _build_example(rng: Any, num_sentences: int, depth: float) -> Tuple[str, str]:
    """Build one haystack prompt and return ``(prompt, secret)``.

    ``num_sentences`` filler sentences are drawn with replacement from
    ``FILLERS``. The fact sentence is inserted at index
    ``int(depth * num_sentences)``, so depth 0.0 puts it first and 1.0 last.
    ``rng`` is either the ``random`` module or a ``random.Random`` instance.
    """
    sents = [rng.choice(FILLERS) for _ in range(num_sentences)]
    secret = f"FACT-{rng.randint(1000, 9999)}"
    fact = f"The critical fact is: {secret}."
    sents.insert(int(depth * len(sents)), fact)
    doc = " ".join(sents)
    prompt = f"Read the text and extract the critical fact.\n\n{doc}\n\nCritical fact:"
    return prompt, secret


def run_remember(
    model_name: str,
    num_sentences: int,
    num_examples: int,
    out_dir: str,
    depths: Optional[List[float]] = None,
    *,
    seed: Optional[int] = None,
) -> Dict[str, Any]:
    """Run the "remember" fact-recall experiment at several insertion depths.

    For each depth, ``num_examples`` prompts are built by hiding a random
    ``FACT-NNNN`` token among ``num_sentences`` filler sentences. Each prompt is
    sent to ``generate_text`` (at most 20 new tokens), and the answer is scored
    with ``exact_match_score`` against the secret. Per-depth predictions are
    written to ``remember_depth_{depth}.jsonl`` in ``out_dir``, and the summary
    is written to ``remember_summary.json``.

    Args:
        model_name: Model identifier forwarded to ``generate_text``.
        num_sentences: Number of filler sentences per haystack.
        num_examples: Number of prompts evaluated per depth.
        out_dir: Output directory; it is created via ``ensure_dir``.
        depths: Relative positions of the fact in [0, 1]. Defaults to nine
            evenly spaced depths from 0.0 to 1.0. Duplicates are evaluated once.
        seed: Optional seed for a private RNG, for reproducible haystacks.
            When ``None``, the global ``random`` module is used, as before.

    Returns:
        A summary dict with the experiment metadata, per-depth accuracy keyed
        by ``str(depth)``, the elapsed ``time_minutes``, and the ``seed``.

    Raises:
        ValueError: If any depth lies outside [0, 1] or is NaN.
    """
    if depths is None:
        depths = list(_DEFAULT_DEPTHS)
    # A repeated depth would redo every model call, then overwrite both its
    # results entry and its JSONL file, so evaluate each depth once (keeping order).
    depths = list(dict.fromkeys(depths))
    # A negative depth would give a negative index, and list.insert counts that
    # from the end, so the fact would land somewhere other than the recorded depth.
    # The chained comparison also rejects NaN.
    bad = [d for d in depths if not 0.0 <= d <= 1.0]
    if bad:
        raise ValueError(f"depths must lie in [0, 1]; got {bad}")

    ensure_dir(out_dir)

    # With no seed, keep using the global module so callers that seed it
    # themselves via random.seed() get the same draws as before.
    rng: Any = random if seed is None else random.Random(seed)

    results: Dict[float, Dict[str, Any]] = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[REMEMBER] Depth {depth:.1%}")
        preds: List[Dict[str, Any]] = []

        for _ in tqdm(range(num_examples), desc=f"Remember {depth:.1%}", leave=False):
            prompt, secret = _build_example(rng, num_sentences, depth)
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=20,
            )
            preds.append({
                "model_answer": ans,
                "correct": exact_match_score(ans, secret),
                "secret": secret,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"remember_depth_{depth}.jsonl"), preds)
        acc = compute_accuracy(preds)
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[REMEMBER] Depth {depth:.1%}: acc={acc:.3f}")

    # Measure once so the logged time and the saved time_minutes agree.
    elapsed_min = (time.time() - start) / 60
    summary = {
        "experiment": "remember",
        "cognitive_level": "remember",
        "num_sentences": num_sentences,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "time_minutes": elapsed_min,
        "seed": seed,
    }
    save_json(os.path.join(out_dir, "remember_summary.json"), summary)
    logger.info(f"[REMEMBER] Time={elapsed_min:.1f} min")
    return summary