| """ |
| Experiment 7: Conversation Memory |
| Critical instruction buried in long chat history. |
| FIXED: 300-turn conversations with decoy instructions to prevent trivial keyword search. |
| Multiple "remember" instructions appear; the model must find the correct one. |
| """ |
| import logging |
| import os |
| import random |
| import time |
| from typing import List, Dict, Any |
|
|
| from tqdm import tqdm |
|
|
| from src.generator import generate_text |
| from src.metrics import exact_match_score, compute_accuracy, position_bias_index |
| from src.plotting import plot_curve |
| from src.utils import ensure_dir, save_jsonl, save_json |
|
|
| logger = logging.getLogger(__name__) |
|
|
# Filler user questions used to pad the synthetic chat history.
# Entry i here corresponds to the reply at index i in ASSISTANT_MSGS.
USER_MSGS: List[str] = [
    "Hello, how are you?",
    "What is the weather like today?",
    "Tell me about quantum physics.",
    "Can you recommend a good book?",
    "What are the health benefits of green tea?",
    "Explain how airplanes fly.",
    "What is the history of the internet?",
    "How do I bake sourdough bread?",
    "What are the best hiking trails in Europe?",
    "Explain neural networks simply.",
    "What is blockchain technology?",
    "How does photosynthesis work?",
    "Tell me a joke.",
    "What is the theory of relativity?",
    "How do vaccines work?",
    "What causes earthquakes?",
    "Explain the water cycle.",
    "What is artificial intelligence?",
    "How do I learn a new language?",
    "What are black holes?",
    "What is dark matter?",
    "How do electric cars work?",
    "What is machine learning?",
    "Explain DNA replication.",
    "How does 3D printing work?",
]
|
|
# Filler assistant replies used to pad the synthetic chat history.
# Entry i here answers the question at index i in USER_MSGS.
ASSISTANT_MSGS: List[str] = [
    "I'm doing well, thank you!",
    "The weather varies by location and season.",
    "Quantum physics studies matter at the smallest scales.",
    "I recommend 'Sapiens' by Yuval Noah Harari.",
    "Green tea contains antioxidants that may boost metabolism.",
    "Airplanes fly due to lift generated by their wings.",
    "The internet evolved from ARPANET in the 1960s.",
    "Sourdough requires flour, water, salt, and a starter culture.",
    "The Tour du Mont Blanc is a spectacular alpine trail.",
    "Neural networks learn patterns from data through layers.",
    "Blockchain is a decentralized digital ledger.",
    "Plants convert sunlight into chemical energy.",
    "Why don't scientists trust atoms? Because they make up everything!",
    "Relativity describes how space and time are interconnected.",
    "Vaccines train the immune system to recognize pathogens.",
    "Earthquakes occur when tectonic plates shift.",
    "Water evaporates, condenses, and precipitates in a cycle.",
    "AI enables machines to perform tasks requiring human intelligence.",
    "Practice daily, immerse yourself, and use spaced repetition.",
    "Black holes have gravitational fields so strong nothing escapes.",
    "Dark matter does not emit light but exerts gravitational pull.",
    "Electric cars use battery packs to power electric motors.",
    "Machine learning is a subset of AI focused on pattern recognition.",
    "DNA replication ensures genetic information is copied accurately.",
    "3D printing builds objects layer by layer from digital models.",
]
|
|
| |
# Templates for distractor "please remember" messages.  Each one has a single
# positional "{}" placeholder that is filled with a throwaway decoy value, and
# none of them mentions an access code.
DECOY_INSTRUCTIONS: List[str] = [
    "Please remember my backup phone number is {}. This is important.",
    "Please remember my work ID is {}. This is important.",
    "Please remember my emergency contact is {}. This is important.",
    "Please remember my security question answer is {}. This is important.",
    "Please remember my alternate email is {}. This is important.",
    "Please remember my locker combination is {}. This is important.",
    "Please remember my membership ID is {}. This is important.",
    "Please remember my API key is {}. This is important.",
    "Please remember my passport number is {}. This is important.",
    "Please remember my account PIN is {}. This is important.",
]
|
|
|
|
| def _make_conversation(num_turns: int, target_instruction: str, ratio: float, num_decoys: int = 8) -> str: |
| """Build conversation with decoy instructions and target at specified depth.""" |
| convo = [] |
| |
| |
| for i in range(num_turns): |
| convo.append(f"User: {random.choice(USER_MSGS)}") |
| convo.append(f"Assistant: {random.choice(ASSISTANT_MSGS)}") |
| |
| |
| target_idx = int(ratio * len(convo)) |
| decoy_positions = set() |
| for _ in range(num_decoys): |
| pos = random.randint(0, len(convo) - 1) |
| |
| if abs(pos - target_idx) > 5: |
| decoy_positions.add(pos) |
| |
| decoy_values = [] |
| for pos in sorted(decoy_positions): |
| tmpl = random.choice(DECOY_INSTRUCTIONS) |
| val = f"DECOY-{random.randint(1000, 9999)}" |
| decoy_values.append(val) |
| convo.insert(pos, f"User: {tmpl.format(val)}") |
| convo.insert(pos + 1, "Assistant: I will remember that.") |
| |
| |
| convo.insert(target_idx, f"User: {target_instruction}") |
| convo.insert(target_idx + 1, "Assistant: I will remember that.") |
| |
| return "\n\n".join(convo) |
|
|
|
|
def run_conversation_memory(
    model_name: str,
    num_turns: int,
    num_examples: int,
    out_dir: str,
    depths: List[float] = None,
) -> Dict[str, Any]:
    """Evaluate recall of a buried access code at several conversation depths.

    For every depth, ``num_examples`` synthetic conversations are generated,
    each hiding a fresh random ``SEC-#####`` code among filler turns and decoy
    instructions.  The model is asked for the code and its answer is scored
    with ``exact_match_score``.

    Files written under ``out_dir``:
        * ``conversation_depth_<depth>.jsonl`` - per-example predictions
        * ``conversation_summary.json`` - the summary returned by this call
        * ``conversation_curve.png`` - accuracy plotted against depth

    Args:
        model_name: Model identifier forwarded to ``generate_text``.
        num_turns: Number of filler turns per conversation.
        num_examples: Number of conversations generated per depth.
        out_dir: Output directory; it is passed to ``ensure_dir`` first.
        depths: Relative target positions to test.  When ``None``, nine
            evenly spaced values from 0.0 to 1.0 are used.

    Returns:
        The summary dict: experiment name, run parameters, accuracy per
        depth (keyed by ``str(depth)``), position bias index and elapsed
        minutes.
    """
    ensure_dir(out_dir)

    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    started_at = time.time()
    accuracy_by_depth = {}

    for depth in depths:
        logger.info(f"[CONVERSATION] Depth {depth:.1%}")
        records = []
        progress = tqdm(
            range(num_examples),
            desc=f"Conversation {depth:.1%}",
            leave=False,
        )
        for _ in progress:
            # A new secret per example, so an answer cannot carry over.
            secret = f"SEC-{random.randint(10000, 99999)}"
            target_instruction = (
                f"Please always remember that my access code is {secret}. "
                "This is extremely important for future questions."
            )
            history = _make_conversation(
                num_turns, target_instruction, depth, num_decoys=8
            )
            prompt = (
                "Here is our conversation history. Multiple instructions were given. "
                "Find the one that mentions my access code and answer with only that code.\n\n"
                f"{history}\n\n"
                "What is my access code?"
            )
            answer = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=20,
            )
            records.append(
                {
                    "model_answer": answer,
                    "correct": exact_match_score(answer, secret),
                    "secret": secret,
                    "depth": depth,
                }
            )

        jsonl_path = os.path.join(out_dir, f"conversation_depth_{depth}.jsonl")
        save_jsonl(jsonl_path, records)
        depth_accuracy = compute_accuracy(records)
        accuracy_by_depth[depth] = depth_accuracy
        logger.info(f"[CONVERSATION] Depth {depth:.1%}: acc={depth_accuracy:.3f}")

    # Accuracies in the order the depths were given, reused for the bias
    # index and for the plot.
    accuracies = [accuracy_by_depth[d] for d in depths]

    summary = {
        "experiment": "conversation_memory",
        "num_turns": num_turns,
        "num_examples": num_examples,
        "depths": {str(d): accuracy_by_depth[d] for d in depths},
        "pbi": position_bias_index(depths, accuracies),
        "time_minutes": (time.time() - started_at) / 60,
    }

    save_json(os.path.join(out_dir, "conversation_summary.json"), summary)
    plot_curve(
        depths,
        accuracies,
        f"Exp 7: Conversation Memory ({num_turns} turns)",
        os.path.join(out_dir, "conversation_curve.png"),
        xlabel="Depth in Chat History (0=start, 1=end)",
    )

    logger.info(f"[CONVERSATION] Time={(time.time()-started_at)/60:.1f} min")
    return summary
|
|