File size: 7,847 Bytes
e1e1ce9
 
 
6ce68d3
 
e1e1ce9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ce68d3
 
 
 
 
e1e1ce9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ce68d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1e1ce9
 
 
6ce68d3
 
e1e1ce9
6ce68d3
 
e1e1ce9
 
 
6ce68d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1e1ce9
 
 
 
 
 
 
 
 
 
6ce68d3
e1e1ce9
 
 
 
 
 
 
 
 
 
 
 
6ce68d3
 
 
 
e1e1ce9
6ce68d3
e1e1ce9
6ce68d3
 
 
 
e1e1ce9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
"""
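Experiment 7: Conversation Memory
A critical instruction is buried at a controlled depth in a long chat history.
FIXED: conversations are long (e.g. 300 turns, set via ``num_turns``) and contain decoy
"Please remember ..." instructions, so a trivial search for "remember" cannot locate the target.
Several such instructions appear; the model must recall the one that mentions the access code.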
"""
import logging
import os
import random
import time
from typing import Any, Dict, List, Optional

from tqdm import tqdm

from src.generator import generate_text
from src.metrics import exact_match_score, compute_accuracy, position_bias_index
from src.plotting import plot_curve
from src.utils import ensure_dir, save_jsonl, save_json

logger = logging.getLogger(__name__)

# Filler small-talk turns. USER_MSGS[i] and ASSISTANT_MSGS[i] form a matching
# question/answer pair (_make_conversation zips them together), so keep both
# lists the same length and in the same topic order.
USER_MSGS = [
    "Hello, how are you?",
    "What is the weather like today?",
    "Tell me about quantum physics.",
    "Can you recommend a good book?",
    "What are the health benefits of green tea?",
    "Explain how airplanes fly.",
    "What is the history of the internet?",
    "How do I bake sourdough bread?",
    "What are the best hiking trails in Europe?",
    "Explain neural networks simply.",
    "What is blockchain technology?",
    "How does photosynthesis work?",
    "Tell me a joke.",
    "What is the theory of relativity?",
    "How do vaccines work?",
    "What causes earthquakes?",
    "Explain the water cycle.",
    "What is artificial intelligence?",
    "How do I learn a new language?",
    "What are black holes?",
    "What is dark matter?",
    "How do electric cars work?",
    "What is machine learning?",
    "Explain DNA replication.",
    "How does 3D printing work?",
]

ASSISTANT_MSGS = [
    "I'm doing well, thank you!",
    "The weather varies by location and season.",
    "Quantum physics studies matter at the smallest scales.",
    "I recommend 'Sapiens' by Yuval Noah Harari.",
    "Green tea contains antioxidants that may boost metabolism.",
    "Airplanes fly due to lift generated by their wings.",
    "The internet evolved from ARPANET in the 1960s.",
    "Sourdough requires flour, water, salt, and a starter culture.",
    "The Tour du Mont Blanc is a spectacular alpine trail.",
    "Neural networks learn patterns from data through layers.",
    "Blockchain is a decentralized digital ledger.",
    "Plants convert sunlight into chemical energy.",
    "Why don't scientists trust atoms? Because they make up everything!",
    "Relativity describes how space and time are interconnected.",
    "Vaccines train the immune system to recognize pathogens.",
    "Earthquakes occur when tectonic plates shift.",
    "Water evaporates, condenses, and precipitates in a cycle.",
    "AI enables machines to perform tasks requiring human intelligence.",
    "Practice daily, immerse yourself, and use spaced repetition.",
    "Black holes have gravitational fields so strong nothing escapes.",
    "Dark matter does not emit light but exerts gravitational pull.",
    "Electric cars use battery packs to power electric motors.",
    "Machine learning is a subset of AI focused on pattern recognition.",
    "DNA replication ensures genetic information is copied accurately.",
    "3D printing builds objects layer by layer from digital models.",
]

# Decoy instructions that reuse the target's "Please remember ..." phrasing but
# concern different facts. Each template has a single ``{}`` slot for its value.
# NOTE(review): decoy values are formatted "DECOY-NNNN" while the caller's real
# secret is "SEC-NNNNN", so the two are separable by prefix alone — confirm
# that this is intended.
DECOY_INSTRUCTIONS = [
    "Please remember my backup phone number is {}. This is important.",
    "Please remember my work ID is {}. This is important.",
    "Please remember my emergency contact is {}. This is important.",
    "Please remember my security question answer is {}. This is important.",
    "Please remember my alternate email is {}. This is important.",
    "Please remember my locker combination is {}. This is important.",
    "Please remember my membership ID is {}. This is important.",
    "Please remember my API key is {}. This is important.",
    "Please remember my passport number is {}. This is important.",
    "Please remember my account PIN is {}. This is important.",
]


def _make_conversation(
    num_turns: int,
    target_instruction: str,
    ratio: float,
    num_decoys: int = 8,
    min_gap_turns: int = 2,
) -> str:
    """Build a chat transcript with decoy instructions and a target at a given depth.

    The transcript is assembled from whole User/Assistant turns, so an inserted
    instruction never splits another exchange. Slots are laid out over the
    *final* transcript (filler + decoys + target), so ``ratio`` is the target's
    true relative depth: 0.0 puts it in the first turn, 1.0 in the last.

    Args:
        num_turns: Number of filler (small-talk) turns.
        target_instruction: User message carrying the fact to be recalled.
        ratio: Relative depth of the target turn; clamped to [0, 1].
        num_decoys: Number of decoy "remember" instructions to insert.
        min_gap_turns: Decoys are kept more than this many turns away from the
            target turn. If the conversation is too short to place every decoy
            under this constraint, fewer decoys are used. The leftover slots
            become extra filler turns, so the transcript length stays
            ``num_turns + num_decoys + 1`` turns.

    Returns:
        The transcript as ``"User: ..."`` / ``"Assistant: ..."`` messages
        joined by blank lines.
    """
    ratio = min(max(ratio, 0.0), 1.0)
    num_turns = max(num_turns, 0)
    num_decoys = max(num_decoys, 0)

    # Positions are chosen over the final turn sequence, so later insertions
    # can never shift the target away from its requested depth.
    total_slots = num_turns + num_decoys + 1
    target_slot = round(ratio * (total_slots - 1))

    candidates = [s for s in range(total_slots) if abs(s - target_slot) > min_gap_turns]
    # Sampling without replacement guarantees distinct slots (no silently lost decoys).
    decoy_slots = set(random.sample(candidates, min(num_decoys, len(candidates))))

    # Pair questions with their matching answers so filler turns stay coherent.
    filler_turns = list(zip(USER_MSGS, ASSISTANT_MSGS))

    convo = []
    for slot in range(total_slots):
        if slot == target_slot:
            user_msg, assistant_msg = target_instruction, "I will remember that."
        elif slot in decoy_slots:
            template = random.choice(DECOY_INSTRUCTIONS)
            value = f"DECOY-{random.randint(1000, 9999)}"
            user_msg, assistant_msg = template.format(value), "I will remember that."
        else:
            user_msg, assistant_msg = random.choice(filler_turns)
        convo.append(f"User: {user_msg}")
        convo.append(f"Assistant: {assistant_msg}")

    return "\n\n".join(convo)


def run_conversation_memory(
    model_name: str,
    num_turns: int,
    num_examples: int,
    out_dir: str,
    depths: Optional[List[float]] = None,
    seed: Optional[int] = None,
) -> Dict[str, Any]:
    """Run the conversation-memory experiment with decoy instructions.

    For each depth, ``num_examples`` conversations are built with
    ``_make_conversation`` (``num_turns`` filler turns, 8 decoys). Each one
    contains a target instruction holding a fresh random ``SEC-NNNNN`` access
    code. The model is asked to return the code, and answers are scored with
    ``exact_match_score``.

    Args:
        model_name: Model identifier passed through to ``generate_text``.
        num_turns: Filler turns per conversation.
        num_examples: Conversations evaluated per depth.
        out_dir: Output directory (passed to ``ensure_dir`` first). It receives
            one ``conversation_depth_<depth>.jsonl`` per depth,
            ``conversation_summary.json`` and ``conversation_curve.png``.
        depths: Relative target depths (0 = start, 1 = end). Defaults to nine
            evenly spaced values from 0.0 to 1.0.
        seed: If given, the module-level ``random`` generator is seeded before
            any sampling so secrets and conversations are reproducible. This
            reseeds global ``random`` state. The value is recorded in the
            summary.

    Returns:
        The summary dict that is also written to ``conversation_summary.json``.
    """
    ensure_dir(out_dir)

    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
    if seed is not None:
        # Secrets, decoys and filler all come from the global ``random`` module.
        random.seed(seed)

    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[CONVERSATION] Depth {depth:.1%}")
        preds = []
        for _ in tqdm(range(num_examples), desc=f"Conversation {depth:.1%}", leave=False):
            secret = f"SEC-{random.randint(10000, 99999)}"
            target_instruction = (
                f"Please always remember that my access code is {secret}. "
                f"This is extremely important for future questions."
            )
            convo = _make_conversation(num_turns, target_instruction, depth, num_decoys=8)
            prompt = (
                f"Here is our conversation history. Multiple instructions were given. "
                f"Find the one that mentions my access code and answer with only that code.\n\n"
                f"{convo}\n\n"
                f"What is my access code?"
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=20,
            )
            correct = exact_match_score(ans, secret)
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "secret": secret,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"conversation_depth_{depth}.jsonl"), preds)
        acc = compute_accuracy(preds)
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[CONVERSATION] Depth {depth:.1%}: acc={acc:.3f}")

    # Per-depth accuracies in ``depths`` order, shared by the summary and the plot.
    accuracies = [results[d]["accuracy"] for d in depths]

    summary = {
        "experiment": "conversation_memory",
        "num_turns": num_turns,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "pbi": position_bias_index(depths, accuracies),
        "seed": seed,
        "time_minutes": (time.time() - start) / 60,
    }

    save_json(os.path.join(out_dir, "conversation_summary.json"), summary)
    plot_curve(
        depths,
        accuracies,
        f"Exp 7: Conversation Memory ({num_turns} turns)",
        os.path.join(out_dir, "conversation_curve.png"),
        xlabel="Depth in Chat History (0=start, 1=end)",
    )

    logger.info(f"[CONVERSATION] Time={(time.time()-start)/60:.1f} min")
    return summary