"""
Retrieval modes for EEG semantic decoding.

Each mode is a class constructed with:
    embedding_index: EmbeddingIndex (an object whose ``search(query, k)``
        returns ``(distances, indices)`` arrays)
    nexus_conn: sqlite3 connection
and exposes ``step(semantic_embedding)``, where semantic_embedding is a
torch.Tensor (the current predicted text embedding).

``step`` returns:
    list of strings (lines to print), or empty list (suppress output this frame)

Drop into EEGSemanticProcessor by replacing the
find_similar_messages + _print_unique_lines path with:
    mode.step(semantic_embedding) -> lines
"""
import numpy as np
import torch
import hashlib
import random
from collections import deque


def fix_encoding(s):
    """Normalize a line of text to clean UTF-8, or return "" to drop it.

    ``str`` input is round-tripped through UTF-8 so that undecodable content
    (surrogate escapes or lone surrogates) becomes U+FFFD. ``bytes`` input is
    decoded as UTF-8, with invalid sequences replaced by U+FFFD.

    Lines whose decoded text contains 'ì', 'í' or 'ï' are dropped (returns "").
    NOTE(review): this looks like a mojibake heuristic; it also drops
    legitimate text containing those letters — confirm this is intended.

    Falsy input (None, "", b"") is returned unchanged.
    """
    if not s:
        return s
    if isinstance(s, str):
        try:
            raw = s.encode('utf-8', 'surrogateescape')
        except UnicodeEncodeError:
            # 'surrogateescape' only maps U+DC80..U+DCFF back to bytes; any
            # other lone surrogate would raise. 'surrogatepass' writes it as
            # invalid UTF-8, which the 'replace' decode turns into U+FFFD.
            raw = s.encode('utf-8', 'surrogatepass')
    else:
        raw = s
    fixed = raw.decode('utf-8', 'replace')
    # Test the decoded text: it is what we return, and testing a str
    # against the original bytes input would raise TypeError.
    if 'ì' in fixed or 'í' in fixed or 'ï' in fixed:
        return ""
    return fixed


def _unit(vec):
    """Return ``vec`` scaled to unit L2 norm (returned unchanged if its norm is 0)."""
    norm = np.linalg.norm(vec)
    return vec / norm if norm > 0 else vec


def _retrieve(embedding_index, nexus_conn, embedding_np, k=64, assistant_only=False):
    """Shared retrieval helper. Returns list of (content, distance) tuples.

    Args:
        embedding_index: object with ``search(query, k) -> (distances, indices)``.
        nexus_conn: sqlite3 connection with a ``messages`` table
            (columns ``id``, ``content``, ``role`` are queried).
        embedding_np: query vector; a 1-D array is reshaped to a single row.
        k: number of neighbours to request from the index.
        assistant_only: if true, only rows with ``role = 'assistant'`` are kept.

    Returns:
        ``(content, distance)`` pairs in the order the index returned them.
        Ids with no matching row, or with empty content, are skipped.
    """
    if embedding_np.ndim == 1:
        embedding_np = embedding_np.reshape(1, -1)
    distances, indices = embedding_index.search(embedding_np, k)

    query = "SELECT content FROM messages WHERE id = ?"
    if assistant_only:
        query += " AND role = 'assistant'"

    results = []
    cursor = nexus_conn.cursor()
    try:
        for msg_id, dist in zip(indices.flatten(), distances.flatten()):
            # -1 is the padding id FAISS-style indexes return when fewer
            # than k neighbours exist; there is nothing to look up.
            if msg_id == -1:
                continue
            cursor.execute(query, (int(msg_id),))
            row = cursor.fetchone()
            if row and row[0]:
                results.append((row[0], float(dist)))
    finally:
        cursor.close()
    return results


def _lines_from_messages(messages, max_lines=60):
    """Extract individual lines from message contents, deduplicated.

    Lines are stripped, passed through ``fix_encoding``, and kept in
    first-seen order. Blank or dropped lines are skipped. At most
    ``max_lines`` lines are returned.
    """
    lines = []
    seen = set()
    for content in messages:
        for raw_line in content.splitlines():
            line = raw_line.strip()
            if not line:
                continue
            line = fix_encoding(line)
            if not line or line in seen:
                continue
            seen.add(line)
            lines.append(line)
            if len(lines) >= max_lines:
                return lines
    return lines


class FloodMode:
    """
    Original behavior: retrieve k candidates, sample, deduplicate against recent windows.
    Fast, noisy, good for raw stream-of-consciousness.

    Parameters:
        search_k: candidates to retrieve from the index
        final_k: how many of the retrieved messages to sample from
        sample_size: messages sampled (at random) per step
        last_n: number of previous steps' line sets to deduplicate against
    """

    def __init__(self, embedding_index, nexus_conn,
                 search_k=180, final_k=90, sample_size=42, last_n=3):
        self.embedding_index = embedding_index
        self.nexus_conn = nexus_conn
        self.search_k = search_k
        self.final_k = final_k
        self.sample_size = sample_size
        # Raw (pre-fix_encoding) line sets from the most recent steps.
        self.previous_sets = deque(maxlen=last_n)

    def step(self, semantic_embedding):
        """Return the sorted lines from a random sample that were absent from recent steps."""
        emb_np = semantic_embedding.detach().cpu().numpy()
        results = _retrieve(self.embedding_index, self.nexus_conn, emb_np, k=self.search_k)
        messages = [content for content, _ in results[:self.final_k]]
        if not messages:
            return []

        sample = random.sample(messages, min(self.sample_size, len(messages)))
        current_lines = {
            stripped
            for msg in sample
            for stripped in (line.strip() for line in msg.splitlines())
            if stripped
        }
        # Dedup on raw lines so the window compares like with like.
        unique = current_lines.difference(*self.previous_sets)
        self.previous_sets.append(current_lines)
        # A set comprehension: fix_encoding can map distinct raw lines to the same text.
        return sorted({line for line in map(fix_encoding, unique) if line})


class DriftMode:
    """
    Emit output only when the semantic pointer moves significantly.

    Retrieves based on the *direction* of movement (current - previous),
    added to the current position. This amplifies whatever the signal
    is shifting toward. "Previous" is the embedding at the last emission
    (or the very first embedding seen), so slow cumulative drift can also
    cross the threshold.

    Parameters:
        move_threshold: minimum cosine distance between the current embedding
            and the embedding at the last emission to trigger output
        amplify: weight of the unit movement direction added to the current
            position (0.0 = pure position; larger leans further toward the
            direction of movement)
        search_k: candidates to retrieve
        cooldown: minimum steps between outputs
        max_lines: maximum lines extracted per emission
    """

    def __init__(self, embedding_index, nexus_conn,
                 search_k=64, move_threshold=0.05, amplify=0.5,
                 cooldown=3, max_lines=30):
        self.embedding_index = embedding_index
        self.nexus_conn = nexus_conn
        self.search_k = search_k
        self.move_threshold = move_threshold
        self.amplify = amplify
        self.cooldown = cooldown
        self.max_lines = max_lines

        self.prev_embedding = None
        self.steps_since_emit = 0
        # All candidate lines of the last emission (before dedup filtering).
        self.prev_lines = set()

    def step(self, semantic_embedding):
        """Return new lines when the embedding has moved enough, else []."""
        emb_normed = _unit(semantic_embedding.detach().cpu().numpy().flatten())
        self.steps_since_emit += 1

        if self.prev_embedding is None:
            self.prev_embedding = emb_normed
            return []

        # Movement since the last emission.
        cos_dist = 1.0 - np.dot(emb_normed, self.prev_embedding)
        if cos_dist < self.move_threshold or self.steps_since_emit < self.cooldown:
            return []

        # Query = current position + amplified unit direction of movement.
        direction = _unit(emb_normed - self.prev_embedding)
        query = _unit(emb_normed + self.amplify * direction)

        self.prev_embedding = emb_normed
        self.steps_since_emit = 0

        results = _retrieve(self.embedding_index, self.nexus_conn,
                            query.reshape(1, -1), k=self.search_k)
        candidates = _lines_from_messages([content for content, _ in results],
                                          self.max_lines)

        # Remove lines seen in the previous emission. Remember the full
        # candidate set (not just what is shown) so persistent lines stay
        # suppressed instead of reappearing every other emission.
        lines = [line for line in candidates if line not in self.prev_lines]
        self.prev_lines = set(candidates)
        return lines


class FocusMode:
    """
    Maintain an exponential moving average of embeddings.
    Only emit when the centroid shifts enough.
    Surfaces the persistent underlying theme rather than moment-to-moment noise.

    Parameters:
        alpha: EMA weight of the newest embedding (lower = smoother, more stable)
        shift_threshold: minimum cosine distance between the current centroid
            and the centroid at the last emission
        search_k: candidates to retrieve
        top_n: how many of the retrieved messages to use (presumably the
            closest to the centroid, if the index returns nearest-first)
        max_lines: maximum lines extracted per emission
    """

    def __init__(self, embedding_index, nexus_conn,
                 search_k=48, alpha=0.15, shift_threshold=0.02,
                 top_n=20, max_lines=25):
        self.embedding_index = embedding_index
        self.nexus_conn = nexus_conn
        self.search_k = search_k
        self.alpha = alpha
        self.shift_threshold = shift_threshold
        self.top_n = top_n
        self.max_lines = max_lines

        self.centroid = None
        self.last_emit_centroid = None
        # All candidate lines of the last emission (before dedup filtering).
        self.prev_lines = set()

    def step(self, semantic_embedding):
        """Update the EMA centroid; return new lines when it has shifted enough, else []."""
        emb_normed = _unit(semantic_embedding.detach().cpu().numpy().flatten())

        if self.centroid is None:
            self.centroid = emb_normed.copy()
            self.last_emit_centroid = emb_normed.copy()
            return []

        # Update EMA centroid, then compare its direction with the last emission.
        self.centroid = self.alpha * emb_normed + (1.0 - self.alpha) * self.centroid
        centroid_normed = _unit(self.centroid)

        cos_dist = 1.0 - np.dot(centroid_normed, self.last_emit_centroid)
        if cos_dist < self.shift_threshold:
            return []

        self.last_emit_centroid = centroid_normed.copy()

        # Retrieve based on smoothed centroid.
        results = _retrieve(self.embedding_index, self.nexus_conn,
                            centroid_normed.reshape(1, -1), k=self.search_k)
        candidates = _lines_from_messages(
            [content for content, _ in results[:self.top_n]], self.max_lines)

        # Deduplicate against the previous emission's full candidate set.
        lines = [line for line in candidates if line not in self.prev_lines]
        self.prev_lines = set(candidates)
        return lines


class LayeredMode:
    """
    Run multiple timescales simultaneously. Show three sections:
      [fast] — what just changed (DriftMode: higher movement threshold, cooldown 1)
      [mid]  — recent theme (FocusMode, EMA with medium alpha)
      [slow] — deep undercurrent (FocusMode, EMA with low alpha)

    Each layer only emits its section when its own threshold is crossed.
    At least one layer must fire for any output. Each emitted section is a
    header line, its lines, then a blank line.

    Parameters:
        search_k: candidates each layer retrieves
        max_lines_per_layer: maximum lines per section
    """

    def __init__(self, embedding_index, nexus_conn, search_k=48, max_lines_per_layer=10):
        self.layers = {
            'fast': DriftMode(embedding_index, nexus_conn, search_k=search_k,
                              move_threshold=0.08, amplify=0.7, cooldown=1,
                              max_lines=max_lines_per_layer),
            'mid': FocusMode(embedding_index, nexus_conn, search_k=search_k,
                             alpha=0.25, shift_threshold=0.03, top_n=16,
                             max_lines=max_lines_per_layer),
            'slow': FocusMode(embedding_index, nexus_conn, search_k=search_k,
                              alpha=0.05, shift_threshold=0.015, top_n=12,
                              max_lines=max_lines_per_layer),
        }

    def step(self, semantic_embedding):
        """Step every layer; return the sections of the layers that fired, or []."""
        # Every layer is stepped each frame so its internal state keeps advancing.
        sections = {}
        for name, layer in self.layers.items():
            lines = layer.step(semantic_embedding)
            if lines:
                sections[name] = lines

        if not sections:
            return []

        output = []
        for name in ['fast', 'mid', 'slow']:
            if name in sections:
                output.append(f"── {name} ──")
                output.extend(sections[name])
                output.append("")
        return output