acb commited on
Commit
732bb64
·
verified ·
1 Parent(s): d69b4dc

Upload retrieval.py

Browse files
Files changed (1) hide show
  1. retrieval.py +319 -0
retrieval.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Retrieval modes for EEG semantic decoding.
3
+
4
+ Each mode is a callable class that takes:
5
+ embedding_index: EmbeddingIndex
6
+ nexus_conn: sqlite3 connection
7
+ semantic_embedding: torch.Tensor (the current predicted text embedding)
8
+
9
+ And returns:
10
+ list of strings (lines to print), or empty list (suppress output this frame)
11
+
12
+ Drop into EEGSemanticProcessor by replacing the find_similar_messages + _print_unique_lines
13
+ path with: mode.step(semantic_embedding) -> lines
14
+ """
15
+
16
+ import numpy as np
17
+ import torch
18
+ import hashlib
19
+ import random
20
+ from collections import deque
21
+
22
+
23
def fix_encoding(s):
    """Best-effort repair of mojibake / mis-encoded message text.

    Accepts either ``str`` or ``bytes``. A ``str`` is round-tripped through
    UTF-8 with ``surrogateescape`` so lone surrogates left over from a bad
    earlier decode become U+FFFD instead of crashing downstream printing.
    Text that still contains telltale mojibake characters ('ì', 'í', 'ï')
    is dropped entirely.

    Returns the repaired string, "" when mojibake is detected, or the
    input unchanged when it is falsy (e.g. None, "", b"").
    """
    if not s:
        return s
    if isinstance(s, str):
        raw = s.encode('utf-8', 'surrogateescape')
    else:
        raw = s
    fixed = raw.decode('utf-8', 'replace')
    # Bug fix: the check must inspect the *decoded* text. The original
    # tested `s`, which raised TypeError for bytes input ('ì' in b"...")
    # and could miss mojibake surfaced by the decode step itself.
    if 'ì' in fixed or 'í' in fixed or 'ï' in fixed:
        return ""
    return fixed
34
+
35
+
36
+ def _retrieve(embedding_index, nexus_conn, embedding_np, k=64, assistant_only=False):
37
+ """Shared retrieval helper. Returns list of (content, distance) tuples."""
38
+ if len(embedding_np.shape) == 1:
39
+ embedding_np = embedding_np.reshape(1, -1)
40
+
41
+ distances, indices = embedding_index.search(embedding_np, k)
42
+ distances = distances.flatten()
43
+ indices = indices.flatten()
44
+
45
+ cursor = nexus_conn.cursor()
46
+ query = "SELECT content FROM messages WHERE id = ?"
47
+ if assistant_only:
48
+ query += " AND role = 'assistant'"
49
+
50
+ results = []
51
+ for msg_id, dist in zip(indices, distances):
52
+ cursor.execute(query, (int(msg_id),))
53
+ row = cursor.fetchone()
54
+ if row and row[0]:
55
+ results.append((row[0], float(dist)))
56
+ return results
57
+
58
+
59
def _lines_from_messages(messages, max_lines=60):
    """Extract individual lines from message contents, deduplicated.

    Lines are stripped, encoding-repaired, and emitted in first-seen order
    until max_lines is reached.
    """
    out = []
    seen = set()
    for content in messages:
        for raw in content.splitlines():
            cleaned = raw.strip()
            if not cleaned:
                continue
            cleaned = fix_encoding(cleaned)
            if not cleaned or cleaned in seen:
                continue
            seen.add(cleaned)
            out.append(cleaned)
            if len(out) >= max_lines:
                return out
    return out
77
+
78
+
79
class FloodMode:
    """
    Original behavior: retrieve k candidates, sample, deduplicate against
    recent windows. Fast, noisy, good for raw stream-of-consciousness.
    """

    def __init__(self, embedding_index, nexus_conn, search_k=180, final_k=90,
                 sample_size=42, last_n=3):
        self.embedding_index = embedding_index
        self.nexus_conn = nexus_conn
        self.search_k = search_k
        self.final_k = final_k
        self.sample_size = sample_size
        # Rolling window of line-sets from the last `last_n` frames.
        self.previous_sets = deque(maxlen=last_n)

    def step(self, semantic_embedding):
        """Retrieve, subsample, and return lines unseen in recent frames."""
        query = semantic_embedding.detach().cpu().numpy()
        hits = _retrieve(self.embedding_index, self.nexus_conn, query,
                         k=self.search_k)

        pool = [content for content, _ in hits[:self.final_k]]
        if not pool:
            return []

        picked = random.sample(pool, min(self.sample_size, len(pool)))

        # Flatten sampled messages into a set of non-empty stripped lines.
        current_lines = set()
        for msg in picked:
            for raw in msg.splitlines():
                stripped = raw.strip()
                if stripped:
                    current_lines.add(stripped)

        # Keep only lines absent from every recent frame, then record
        # this frame's full set for future deduplication.
        fresh = set(current_lines)
        for earlier in self.previous_sets:
            fresh -= earlier
        self.previous_sets.append(current_lines)

        cleaned = [line for line in map(fix_encoding, fresh) if line]
        return sorted(cleaned)
119
+
120
+
121
class DriftMode:
    """
    Emit output only when the semantic pointer moves significantly.
    Retrieves based on the *direction* of movement (current - previous),
    added to the current position. This amplifies whatever the signal
    is shifting toward.

    NOTE(review): the movement baseline (`prev_embedding`) only updates on
    emission, so "movement" is measured against the last emitted frame, not
    strictly consecutive frames — confirm this matches intent.

    Parameters:
        move_threshold: minimum cosine distance between the current
            embedding and the last-emit baseline to trigger output
        amplify: how much to weight the delta (1.0 = pure direction,
            0.0 = pure position)
        search_k: candidates to retrieve
        cooldown: minimum steps between outputs
        max_lines: cap on lines returned per emission
    """

    def __init__(self, embedding_index, nexus_conn, search_k=64,
                 move_threshold=0.05, amplify=0.5, cooldown=3, max_lines=30):
        self.embedding_index = embedding_index
        self.nexus_conn = nexus_conn
        self.search_k = search_k
        self.move_threshold = move_threshold
        self.amplify = amplify
        self.cooldown = cooldown
        self.max_lines = max_lines

        self.prev_embedding = None      # baseline set at last emission
        self.steps_since_emit = 0
        self.prev_lines = set()         # lines shown last emission

    def step(self, semantic_embedding):
        """Return new lines when the embedding has drifted enough, else []."""
        vec = semantic_embedding.detach().cpu().numpy().flatten()

        # Unit-normalize, guarding against the zero vector.
        mag = np.linalg.norm(vec)
        unit = vec / mag if mag > 0 else vec

        self.steps_since_emit += 1

        # First frame only seeds the baseline.
        if self.prev_embedding is None:
            self.prev_embedding = unit
            return []

        # Cosine distance to the baseline.
        movement = 1.0 - np.dot(unit, self.prev_embedding)
        if movement < self.move_threshold or self.steps_since_emit < self.cooldown:
            return []

        # Unit direction of movement away from the baseline.
        step_dir = unit - self.prev_embedding
        dir_mag = np.linalg.norm(step_dir)
        if dir_mag > 0:
            step_dir = step_dir / dir_mag

        # Query point = current position pushed along the movement direction.
        probe = unit + self.amplify * step_dir
        probe_mag = np.linalg.norm(probe)
        if probe_mag > 0:
            probe = probe / probe_mag

        self.prev_embedding = unit
        self.steps_since_emit = 0

        hits = _retrieve(self.embedding_index, self.nexus_conn,
                         probe.reshape(1, -1), k=self.search_k)
        lines = _lines_from_messages([content for content, _ in hits],
                                     self.max_lines)

        # Suppress anything repeated from the previous emission.
        lines = [ln for ln in lines if ln not in self.prev_lines]
        self.prev_lines = set(lines)
        return lines
200
+
201
+
202
class FocusMode:
    """
    Maintain an exponential moving average of embeddings. Only emit
    when the centroid shifts enough. Surfaces the persistent underlying
    theme rather than moment-to-moment noise.

    Parameters:
        alpha: EMA smoothing factor (lower = smoother, more stable)
        shift_threshold: minimum cosine distance of centroid movement to emit
        search_k: candidates to retrieve
        top_n: how many top results to show (closest to centroid)
        max_lines: cap on lines returned per emission
    """

    def __init__(self, embedding_index, nexus_conn, search_k=48,
                 alpha=0.15, shift_threshold=0.02, top_n=20, max_lines=25):
        self.embedding_index = embedding_index
        self.nexus_conn = nexus_conn
        self.search_k = search_k
        self.alpha = alpha
        self.shift_threshold = shift_threshold
        self.top_n = top_n
        self.max_lines = max_lines

        # EMA centroid of normalized embeddings (None until first step).
        self.centroid = None
        # Centroid snapshot at the last emission, for shift detection.
        self.last_emit_centroid = None
        # Lines shown last emission, suppressed on the next one.
        self.prev_lines = set()

    def step(self, semantic_embedding):
        """Update the EMA centroid; emit lines when it has drifted enough.

        Returns a list of new lines, or [] when the centroid is stable
        (or on the first call, which only seeds the centroid).
        """
        emb_np = semantic_embedding.detach().cpu().numpy().flatten()

        # Unit-normalize, guarding against the zero vector.
        norm = np.linalg.norm(emb_np)
        emb_normed = emb_np / norm if norm > 0 else emb_np

        # First frame seeds the centroid; nothing to compare against yet.
        if self.centroid is None:
            self.centroid = emb_normed.copy()
            self.last_emit_centroid = emb_normed.copy()
            return []

        # EMA update, then re-normalize so cosine math stays valid.
        self.centroid = self.alpha * emb_normed + (1.0 - self.alpha) * self.centroid
        c_norm = np.linalg.norm(self.centroid)
        centroid_normed = self.centroid / c_norm if c_norm > 0 else self.centroid

        # Emit only if the centroid moved enough since the last emission.
        cos_dist = 1.0 - np.dot(centroid_normed, self.last_emit_centroid)
        if cos_dist < self.shift_threshold:
            return []

        self.last_emit_centroid = centroid_normed.copy()

        # Retrieve based on the smoothed centroid.
        results = _retrieve(self.embedding_index, self.nexus_conn,
                            centroid_normed.reshape(1, -1), k=self.search_k)
        messages = [content for content, _ in results[:self.top_n]]
        lines = _lines_from_messages(messages, self.max_lines)

        # Fix: the original rebuilt the identical line list a second time
        # just to populate prev_lines. Compute once, filter against the
        # previous emission, and remember the full (pre-filter) set so
        # repeats stay suppressed next time. Behavior is unchanged.
        fresh = [l for l in lines if l not in self.prev_lines]
        self.prev_lines = set(lines)
        return fresh
275
+
276
+
277
class LayeredMode:
    """
    Run multiple timescales simultaneously. Show three sections:

        [fast] — what just changed (high threshold, small k)
        [mid]  — recent theme (EMA with medium alpha)
        [slow] — deep undercurrent (EMA with low alpha)

    Each layer only emits its section when its own threshold is crossed.
    At least one layer must fire for any output.
    """

    def __init__(self, embedding_index, nexus_conn, search_k=48, max_lines_per_layer=10):
        self.layers = {
            'fast': DriftMode(embedding_index, nexus_conn, search_k=search_k,
                              move_threshold=0.08, amplify=0.7, cooldown=1,
                              max_lines=max_lines_per_layer),
            'mid': FocusMode(embedding_index, nexus_conn, search_k=search_k,
                             alpha=0.25, shift_threshold=0.03, top_n=16,
                             max_lines=max_lines_per_layer),
            'slow': FocusMode(embedding_index, nexus_conn, search_k=search_k,
                              alpha=0.05, shift_threshold=0.015, top_n=12,
                              max_lines=max_lines_per_layer),
        }

    def step(self, semantic_embedding):
        """Step every layer; return headed sections for those that fired."""
        # Every layer steps each frame so its internal state keeps advancing
        # even when it produces no output.
        fired = {}
        for layer_name, layer in self.layers.items():
            emitted = layer.step(semantic_embedding)
            if emitted:
                fired[layer_name] = emitted

        if not fired:
            return []

        # Fixed presentation order, regardless of which layers fired.
        output = []
        for layer_name in ('fast', 'mid', 'slow'):
            if layer_name not in fired:
                continue
            output.append(f"── {layer_name} ──")
            output.extend(fired[layer_name])
            output.append("")
        return output