"""
Experiment 7: Conversation Memory.

A critical instruction (an "access code") is buried at a controlled depth in
a long synthetic chat history (e.g. 300 turns), and the model is asked to
recall it.

Several decoy "Please remember my ... is ..." instructions are scattered
through the history as well, so the model has to pick out the one that
concerns the access code rather than just the only "remember" instruction.
"""

import logging
import os
import random
import time
from typing import Any, Dict, List, Optional, Tuple

from tqdm import tqdm

from src.generator import generate_text
from src.metrics import exact_match_score, compute_accuracy, position_bias_index
from src.plotting import plot_curve
from src.utils import ensure_dir, save_jsonl, save_json

logger = logging.getLogger(__name__)

USER_MSGS = [
    "Hello, how are you?",
    "What is the weather like today?",
    "Tell me about quantum physics.",
    "Can you recommend a good book?",
    "What are the health benefits of green tea?",
    "Explain how airplanes fly.",
    "What is the history of the internet?",
    "How do I bake sourdough bread?",
    "What are the best hiking trails in Europe?",
    "Explain neural networks simply.",
    "What is blockchain technology?",
    "How does photosynthesis work?",
    "Tell me a joke.",
    "What is the theory of relativity?",
    "How do vaccines work?",
    "What causes earthquakes?",
    "Explain the water cycle.",
    "What is artificial intelligence?",
    "How do I learn a new language?",
    "What are black holes?",
    "What is dark matter?",
    "How do electric cars work?",
    "What is machine learning?",
    "Explain DNA replication.",
    "How does 3D printing work?",
]

ASSISTANT_MSGS = [
    "I'm doing well, thank you!",
    "The weather varies by location and season.",
    "Quantum physics studies matter at the smallest scales.",
    "I recommend 'Sapiens' by Yuval Noah Harari.",
    "Green tea contains antioxidants that may boost metabolism.",
    "Airplanes fly due to lift generated by their wings.",
    "The internet evolved from ARPANET in the 1960s.",
    "Sourdough requires flour, water, salt, and a starter culture.",
    "The Tour du Mont Blanc is a spectacular alpine trail.",
    "Neural networks learn patterns from data through layers.",
    "Blockchain is a decentralized digital ledger.",
    "Plants convert sunlight into chemical energy.",
    "Why don't scientists trust atoms? Because they make up everything!",
    "Relativity describes how space and time are interconnected.",
    "Vaccines train the immune system to recognize pathogens.",
    "Earthquakes occur when tectonic plates shift.",
    "Water evaporates, condenses, and precipitates in a cycle.",
    "AI enables machines to perform tasks requiring human intelligence.",
    "Practice daily, immerse yourself, and use spaced repetition.",
    "Black holes have gravitational fields so strong nothing escapes.",
    "Dark matter does not emit light but exerts gravitational pull.",
    "Electric cars use battery packs to power electric motors.",
    "Machine learning is a subset of AI focused on pattern recognition.",
    "DNA replication ensures genetic information is copied accurately.",
    "3D printing builds objects layer by layer from digital models.",
]

# USER_MSGS[i] and ASSISTANT_MSGS[i] are a matching question/answer, so filler
# turns are drawn as whole pairs to keep the synthetic chat coherent.
_QA_PAIRS: List[Tuple[str, str]] = list(zip(USER_MSGS, ASSISTANT_MSGS))

# Assistant reply that follows every "remember" instruction (decoy or target).
_ACK = "I will remember that."

# Decoy instructions that use the same pattern but different values
DECOY_INSTRUCTIONS = [
    "Please remember my backup phone number is {}. This is important.",
    "Please remember my work ID is {}. This is important.",
    "Please remember my emergency contact is {}. This is important.",
    "Please remember my security question answer is {}. This is important.",
    "Please remember my alternate email is {}. This is important.",
    "Please remember my locker combination is {}. This is important.",
    "Please remember my membership ID is {}. This is important.",
    "Please remember my API key is {}. This is important.",
    "Please remember my passport number is {}. This is important.",
    "Please remember my account PIN is {}. This is important.",
]


def _make_conversation(
    num_turns: int,
    target_instruction: str,
    ratio: float,
    num_decoys: int = 8,
    min_gap: int = 3,
) -> str:
    """Build a chat transcript with decoys and the target at a given depth.

    The transcript is a sequence of whole User/Assistant turns:
    ``num_turns`` ordinary Q/A turns, up to ``num_decoys`` decoy "remember"
    turns, and exactly one target turn. All positions are chosen in terms of
    the *final* turn sequence, so ``ratio`` is the target's true relative
    depth (0.0 = first turn, 1.0 = last turn) and inserted turns never split
    a User message from its Assistant reply.

    Args:
        num_turns: Number of ordinary filler turns.
        target_instruction: User text of the turn the model must recall.
        ratio: Relative depth of the target turn; clamped into [0, 1].
        num_decoys: Requested number of decoy instruction turns.
        min_gap: Minimum number of ordinary turns kept between the target and
            the nearest decoy. For very short conversations this can leave
            too few eligible slots, in which case fewer decoys are placed.

    Returns:
        Messages formatted as ``"User: ..."`` / ``"Assistant: ..."`` joined
        by blank lines.
    """
    # Positions index the filler sequence (ordinary + decoy turns); the target
    # is inserted *before* filler[target_slot]. If not enough slots remain
    # outside the protected window, shrink the decoy count and re-plan; this
    # terminates because n_decoys strictly decreases towards 0.
    n_decoys = max(0, num_decoys)
    while True:
        total = num_turns + n_decoys
        target_slot = min(max(int(ratio * total), 0), total)
        eligible = [
            slot
            for slot in range(total)
            if not (target_slot - min_gap <= slot < target_slot + min_gap)
        ]
        if len(eligible) >= n_decoys:
            break
        n_decoys = len(eligible)

    decoy_slots = set(random.sample(eligible, n_decoys))

    turns: List[Tuple[str, str]] = []
    for slot in range(total):
        if slot in decoy_slots:
            # NOTE(review): decoy values carry a literal "DECOY-" prefix and no
            # template mentions an access code, so the target is still easy to
            # single out by keyword; consider matching the secret's format.
            template = random.choice(DECOY_INSTRUCTIONS)
            value = f"DECOY-{random.randint(1000, 9999)}"
            turns.append((template.format(value), _ACK))
        else:
            turns.append(random.choice(_QA_PAIRS))
    turns.insert(target_slot, (target_instruction, _ACK))

    messages: List[str] = []
    for user_msg, assistant_msg in turns:
        messages.append(f"User: {user_msg}")
        messages.append(f"Assistant: {assistant_msg}")
    return "\n\n".join(messages)


def run_conversation_memory(
    model_name: str,
    num_turns: int,
    num_examples: int,
    out_dir: str,
    depths: Optional[List[float]] = None,
    num_decoys: int = 8,
) -> Dict[str, Any]:
    """Run the conversation-memory experiment across insertion depths.

    For each depth, ``num_examples`` conversations are generated, each with a
    fresh random access code buried at that depth among decoy instructions.
    The model is asked for the code and scored by exact match.

    Args:
        model_name: Model identifier forwarded to ``generate_text``.
        num_turns: Number of ordinary User/Assistant turns per conversation.
        num_examples: Conversations evaluated per depth.
        out_dir: Output directory (created via ``ensure_dir``). It receives
            per-depth JSONL predictions, ``conversation_summary.json`` and
            ``conversation_curve.png``.
        depths: Relative target depths (0 = start, 1 = end). Defaults to nine
            evenly spaced values from 0.0 to 1.0.
        num_decoys: Decoy "remember" instructions per conversation.

    Returns:
        The summary dict that is also written to
        ``conversation_summary.json``.
    """
    ensure_dir(out_dir)
    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    # One accuracy per entry of `depths`, in order (kept as a list so repeated
    # depths do not overwrite each other's results).
    accuracies: List[float] = []
    start = time.time()

    for depth in depths:
        logger.info(f"[CONVERSATION] Depth {depth:.1%}")
        preds = []
        for _ in tqdm(range(num_examples), desc=f"Conversation {depth:.1%}", leave=False):
            secret = f"SEC-{random.randint(10000, 99999)}"
            target_instruction = (
                f"Please always remember that my access code is {secret}. "
                f"This is extremely important for future questions."
            )
            convo = _make_conversation(
                num_turns, target_instruction, depth, num_decoys=num_decoys
            )
            prompt = (
                "Here is our conversation history. Multiple instructions were given. "
                "Find the one that mentions my access code and answer with only that code.\n\n"
                f"{convo}\n\n"
                "What is my access code?"
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=20,
            )
            preds.append({
                "model_answer": ans,
                "correct": exact_match_score(ans, secret),
                "secret": secret,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"conversation_depth_{depth}.jsonl"), preds)
        acc = compute_accuracy(preds)
        accuracies.append(acc)
        logger.info(f"[CONVERSATION] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": "conversation_memory",
        "num_turns": num_turns,
        "num_examples": num_examples,
        "num_decoys": num_decoys,
        "depths": {str(d): a for d, a in zip(depths, accuracies)},
        "pbi": position_bias_index(depths, accuracies),
        "time_minutes": (time.time() - start) / 60,
    }
    save_json(os.path.join(out_dir, "conversation_summary.json"), summary)

    plot_curve(
        depths,
        accuracies,
        f"Exp 7: Conversation Memory ({num_turns} turns)",
        os.path.join(out_dir, "conversation_curve.png"),
        xlabel="Depth in Chat History (0=start, 1=end)",
    )

    logger.info(f"[CONVERSATION] Time={(time.time() - start) / 60:.1f} min")
    return summary