"""
morphism — EEG-to-text semantic search

Usage:
    morphism record [options]
    morphism index create|info|rebuild [options]
    morphism decode [options]
"""
|
|
| import sys |
| import os |
| import argparse |
|
|
| from retrieval import FloodMode, DriftMode, FocusMode, LayeredMode |
|
|
def cmd_record(args):
    """Record EEG data from an OpenBCI Cyton+Daisy board.

    Streams 16-channel samples over serial and writes CSV rows either to a
    local file, to a remote file over SFTP (``--remote``), or delegates
    entirely to SD-card recording (``--sd``).  Stops on Ctrl+C.

    Args:
        args: argparse namespace with port, output, sample_rate, sd,
              duration, and remote attributes (see ``main``).
    """
    from cyton import (
        init_board, set_sample_rate, read_complete_packet, process_packet,
        start_sd_recording, stop_sd_recording, create_ssh_connection, sd_record
    )
    import serial, time, io
    from datetime import datetime

    # SD-card mode is handled entirely by the cyton helper.
    if args.sd:
        sd_record(args.port, args.duration, args.sample_rate)
        return

    filename = args.output
    if filename is None:
        filename = f"openbci_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"

    ser = serial.Serial(args.port, 115200)
    time.sleep(2)  # let the board settle after the port opens
    init_board(ser)

    # Board default is 1000 Hz; only reconfigure when a different rate is asked for.
    if args.sample_rate != 1000:
        set_sample_rate(ser, args.sample_rate)

    ssh, sftp, remote_file = None, None, None
    if args.remote:
        ssh = create_ssh_connection()
        if not ssh:
            print("SSH connection failed.")
            return
        sftp = ssh.open_sftp()
        remote_file = sftp.open(filename, 'w')

    header = "Timestamp," + ",".join(f"Channel{i+1}" for i in range(16)) + "\n"
    if args.remote:
        remote_file.write(header)
    else:
        with open(filename, 'w') as f:
            f.write(header)

    ser.write(b'b')  # 'b' = begin streaming
    time.sleep(0.5)
    ser.reset_input_buffer()

    # BUG FIX: this message previously printed a literal "(unknown)"
    # instead of interpolating the output filename.
    print(f"Recording to {filename} — Ctrl+C to stop")

    pkt_count = 0
    t0 = time.time()
    buf = io.StringIO()
    last_flush = time.time()

    try:
        while True:
            # Cyton+Daisy sends two interleaved packets (8 channels each) per sample.
            p1 = read_complete_packet(ser)
            if not p1:
                continue
            p2 = read_complete_packet(ser)
            if not p2:
                continue

            d1, d2 = process_packet(p1), process_packet(p2)
            if not (d1 and d2):
                continue

            pkt_count += 1
            ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
            line = ts + "," + ",".join(f"{x:.6f}" for x in d1 + d2) + "\n"

            if args.remote:
                # Buffer locally and push over SFTP at most every 100 ms to
                # avoid a round-trip per sample.
                buf.write(line)
                if time.time() - last_flush >= 0.1:
                    remote_file.write(buf.getvalue())
                    buf = io.StringIO()
                    last_flush = time.time()
            else:
                with open(filename, 'a') as f:
                    f.write(line)

            if pkt_count % 125 == 0:
                rate = pkt_count / (time.time() - t0)
                print(f"\r {rate:.1f} Hz, {pkt_count} packets", end='')

            # If we fall behind the serial stream, drop stale bytes rather
            # than drift ever further from real time.
            if ser.in_waiting > 1000:
                ser.reset_input_buffer()

    except KeyboardInterrupt:
        ser.write(b's')  # 's' = stop streaming
        ser.close()
        if args.remote:
            if buf.getvalue():
                remote_file.write(buf.getvalue())  # flush remaining buffered rows
            remote_file.close()
            sftp.close()
            ssh.close()

        elapsed = time.time() - t0
        print(f"\n\nDone — {pkt_count} packets in {elapsed:.1f}s ({pkt_count/elapsed:.1f} Hz)")
        # BUG FIX: also printed "(unknown)" instead of the filename.
        print(f"Saved to {filename}")
|
|
|
|
def cmd_index(args):
    """Create, rebuild, or inspect the text embedding index.

    Actions (``args.action``):
      info    — print database/index statistics and return.
      create  — embed any not-yet-indexed corpus text, then build the FAISS index.
      rebuild — wipe existing rows first, then proceed as for create.
    """
    from embed import (
        get_splitter, process_batch, create_index_if_possible,
        get_existing_content, INITIAL_BATCH_SIZE, MIN_BATCH_SIZE, SHUFFLE_SEED
    )
    import sqlite3, random
    from tqdm import tqdm

    db_path = os.path.expanduser(args.db)
    index_prefix = args.index

    if args.action == 'info':
        if not os.path.exists(db_path):
            print(f"No database at {db_path}")
            return

        conn = sqlite3.connect(db_path)
        c = conn.cursor()
        c.execute("SELECT COUNT(*) FROM messages")
        msg_count = c.fetchone()[0]
        c.execute("SELECT COUNT(*) FROM embeddings")
        emb_count = c.fetchone()[0]
        conn.close()

        index_exists = os.path.exists(f"{index_prefix}.index")

        print(f"Database: {db_path}")
        print(f"Messages: {msg_count:,}")
        print(f"Embeddings: {emb_count:,}")
        print(f"FAISS index: {'exists' if index_exists else 'not built'} ({index_prefix}.index)")
        return

    if args.action in ('create', 'rebuild'):
        corpus = os.path.expanduser(args.corpus)
        if not os.path.isdir(corpus):
            print(f"Not a directory: {corpus}")
            sys.exit(1)

    splitter = get_splitter(args.split_mode, args.chunk_size, args.chunk_overlap)

    print(f"Loading model: {args.model}")
    from transformers import AutoModel
    model = AutoModel.from_pretrained(args.model, trust_remote_code=True).cuda()
    model.eval()

    conn = sqlite3.connect(db_path)
    c = conn.cursor()

    # Ensure the schema exists BEFORE any rebuild deletion.  BUG FIX: the
    # DELETEs used to run first, so `rebuild` against a fresh database
    # crashed with "no such table".  Also fixed the foreign key, which
    # referenced messages(message_id) although the messages primary key
    # column is `id`.
    c.execute("""CREATE TABLE IF NOT EXISTS messages (
        id INTEGER PRIMARY KEY AUTOINCREMENT, content TEXT, role TEXT)""")
    c.execute("""CREATE TABLE IF NOT EXISTS embeddings (
        message_id INTEGER PRIMARY KEY, embedding BLOB,
        FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE)""")
    conn.commit()
    create_index_if_possible(c)
    conn.commit()

    if args.action == 'rebuild':
        print("Dropping existing data...")
        c.execute("DELETE FROM embeddings")
        c.execute("DELETE FROM messages")
        conn.commit()

    existing = get_existing_content(c)
    print(f"Already indexed: {len(existing):,}")

    txt_files = [f for f in os.listdir(corpus) if f.lower().endswith('.txt')]
    if not txt_files:
        print(f"No .txt files in {corpus}")
        conn.close()
        return

    units = []
    for fn in txt_files:
        with open(os.path.join(corpus, fn), 'r', encoding='utf-8', errors='ignore') as f:
            units.extend(splitter(f.read()))

    # Deterministic shuffle so interrupted runs resume over the same ordering.
    random.seed(SHUFFLE_SEED)
    random.shuffle(units)
    new_units = [u for u in units if u not in existing]
    print(f"New units to embed: {len(new_units):,}")

    if not new_units:
        print("Nothing new.")
        conn.close()
        return

    batch_size = args.batch_size
    idx = 0
    processed = 0

    with tqdm(total=len(new_units), desc="Embedding") as pbar:
        while idx < len(new_units):
            batch = new_units[idx:idx + batch_size]
            ok = process_batch(model, batch, c, args.task)
            if ok:
                conn.commit()
                pbar.update(len(batch))
                processed += len(batch)
                idx += len(batch)
            else:
                # Failure is treated as GPU OOM: halve the batch size, and
                # once at the minimum, skip the offending unit entirely.
                if batch_size > MIN_BATCH_SIZE:
                    batch_size = max(batch_size // 2, MIN_BATCH_SIZE)
                    print(f"\nOOM — batch size → {batch_size}")
                else:
                    idx += 1
                    pbar.update(1)
                    processed += 1

    conn.close()
    print(f"Embedded {processed:,} units.")

    print("Building FAISS index...")
    _build_faiss_index(db_path, index_prefix)
    print("Done.")
|
|
|
def _build_faiss_index(db_path, index_prefix):
    """Build the FAISS index from the embeddings table of the SQLite db.

    Writes ``{index_prefix}.index`` (via EmbeddingIndex.save) and a
    ``{index_prefix}_metadata.npz`` with row count and max message id,
    used to detect a stale index.
    """
    import sqlite3, numpy as np
    from decode import EmbeddingIndex

    conn = sqlite3.connect(db_path)
    c = conn.cursor()
    # ORDER BY keeps vector order deterministic across rebuilds.
    c.execute("SELECT message_id, embedding FROM embeddings ORDER BY message_id")

    embeddings, ids = [], []
    for mid, blob in c:  # iterate the cursor directly; no need to materialize fetchall()
        embeddings.append(np.frombuffer(blob, dtype=np.float32))
        ids.append(mid)
    conn.close()

    if not embeddings:
        print(" No embeddings found.")
        return

    embeddings = np.vstack(embeddings)
    print(f" {len(embeddings):,} vectors, dim={embeddings.shape[1]}")

    idx = EmbeddingIndex(dim=embeddings.shape[1])
    idx.add_embeddings(embeddings, ids)
    idx.save(index_prefix)

    # Metadata for staleness checks.  Previously this reopened the database
    # and re-ran COUNT/MAX queries; both values are already in `ids`.
    np.savez(f"{index_prefix}_metadata.npz", count=len(ids), max_message_id=max(ids))
|
|
|
|
def cmd_decode(args):
    """Run EEG → text decoding.

    Streams windows from the EEG file, maps each window to a semantic
    embedding, and either collects raw vectors (``--save-vectors``) or feeds
    them to the selected retrieval mode, printing (and optionally logging)
    the retrieved text lines.  Stops on Ctrl+C.
    """
    # Hoisted imports: `time` was previously re-imported on every loop
    # iteration and `logging`/`sys` inside inner branches (`sys` is already
    # a module-level import and needs no re-import at all).
    import logging
    import time
    from decode import EEGSemanticProcessor

    processor = EEGSemanticProcessor(
        autoencoder_model_path=args.autoencoder,
        semantic_model_path=args.semantic,
        nexus_db_path=args.db,
        embeddings_db_path=args.db,
        index_path=args.index,
        eeg_file_path=args.eeg,
        window_size=args.window_size,
        stride=args.stride,
        batch_size=args.batch_size,
        device=args.device,
        search_k=args.search_k,
        final_k=args.final_k,
        use_raw_eeg=args.raw_eeg,
        input_dim_override=args.input_dim,
        save_vectors=args.save_vectors,
        vector_output_path=args.vector_output,
        last_n_messages=args.last_n,
    )

    # Lazy factories so only the selected mode is constructed.
    modes = {
        'flood': lambda: FloodMode(processor.embedding_index, processor.nexus_conn,
                                   search_k=args.search_k, final_k=args.final_k,
                                   last_n=args.last_n),
        'drift': lambda: DriftMode(processor.embedding_index, processor.nexus_conn,
                                   search_k=64),
        'focus': lambda: FocusMode(processor.embedding_index, processor.nexus_conn,
                                   search_k=48),
        'layered': lambda: LayeredMode(processor.embedding_index, processor.nexus_conn),
    }

    mode = modes[args.mode]()

    processor.eeg_stream.start()
    try:
        consecutive_errors = 0
        while True:
            try:
                for embedding_data in processor.eeg_stream.get_embeddings(timeout=0.5):
                    try:
                        semantic_embedding = processor.process_eeg_embedding(
                            embedding_data['embedding'])

                        if processor.save_vectors:
                            # Vector-collection mode: accumulate and skip retrieval.
                            embedding_np = semantic_embedding.detach().cpu().numpy()
                            processor.vectors_list.append(embedding_np)
                            processor.timestamps.append({
                                'start': embedding_data['start_timestamp'],
                                'end': embedding_data['end_timestamp']
                            })
                            if len(processor.vectors_list) % 100 == 0:
                                logging.getLogger("EEGSemanticStream").info(
                                    f"Collected {len(processor.vectors_list)} vectors")
                            continue

                        lines = mode.step(semantic_embedding)
                        if lines:
                            output = "\n".join(lines)
                            print(output)
                            if processor.log_file:
                                processor.log_file.write(output + "\n")
                                processor.log_file.flush()

                        consecutive_errors = 0  # any success resets the error budget
                    except Exception as e:
                        # Per-window failure: report and count; five in a row
                        # escalates via the sentinel RuntimeError below.
                        print(f"Error: {e}", file=sys.stderr)
                        consecutive_errors += 1
                        if consecutive_errors >= 5:
                            raise RuntimeError("Too many consecutive errors")

                time.sleep(0.01)
            except Exception as e:
                # NOTE: the sentinel is recognized by message substring — keep
                # the "Too many" text in sync with the RuntimeError above.
                if "Too many" in str(e):
                    raise
                print(f"Error: {e}", file=sys.stderr)
                consecutive_errors += 1
                if consecutive_errors >= 5:
                    raise
                time.sleep(1)  # back off before retrying the stream loop
    except KeyboardInterrupt:
        pass
    except Exception as e:
        print(f"Fatal: {e}", file=sys.stderr)
    finally:
        if processor.save_vectors and processor.vectors_list:
            processor.save_vectors_to_disk()
        processor.eeg_stream.stop()
|
def main():
    """Parse command-line arguments and dispatch to the subcommand handlers."""
    parser = argparse.ArgumentParser(
        prog='morphism',
        description='EEG-to-text semantic search',
    )
    subparsers = parser.add_subparsers(dest='command')

    # --- record ---------------------------------------------------------
    record_p = subparsers.add_parser('record', help='Record EEG from OpenBCI Cyton+Daisy')
    record_p.add_argument('--port', '-p', default='/dev/ttyUSB0')
    record_p.add_argument('--output', '-o', default=None)
    record_p.add_argument('--sample-rate', type=int, default=1000)
    record_p.add_argument('--sd', action='store_true', help='Record to SD card')
    record_p.add_argument('--duration', default='G')
    record_p.add_argument('--remote', action='store_true', help='Stream via SSH')

    # --- index ----------------------------------------------------------
    index_p = subparsers.add_parser('index', help='Manage the text embedding index')
    index_p.add_argument('action', choices=['create', 'info', 'rebuild'])
    index_p.add_argument('--corpus', '-c', default=None)
    index_p.add_argument('--db', default='morphism.db')
    index_p.add_argument('--index', default='morphism')
    index_p.add_argument('--split-mode', default='line',
                         choices=['line', 'block', 'sentence', 'chunk'])
    index_p.add_argument('--chunk-size', type=int, default=512)
    index_p.add_argument('--chunk-overlap', type=int, default=64)
    index_p.add_argument('--batch-size', type=int, default=128)
    index_p.add_argument('--task', default='text-matching')
    index_p.add_argument('--model', default='jinaai/jina-embeddings-v3')

    # --- decode ---------------------------------------------------------
    decode_p = subparsers.add_parser('decode', help='Run EEG → text decoding')
    decode_p.add_argument('--mode', default='flood', choices=['flood', 'drift', 'focus', 'layered'])
    decode_p.add_argument('--eeg', '-f', required=True)
    decode_p.add_argument('--autoencoder', '-a', required=True)
    decode_p.add_argument('--semantic', '-s', required=True)
    decode_p.add_argument('--db', default='morphism.db')
    decode_p.add_argument('--index', default='morphism')
    decode_p.add_argument('--window-size', type=int, default=624)
    decode_p.add_argument('--stride', type=int, default=32)
    decode_p.add_argument('--batch-size', type=int, default=32)
    decode_p.add_argument('--device', default=None)
    decode_p.add_argument('--search-k', type=int, default=1024)
    decode_p.add_argument('--final-k', type=int, default=1024)
    decode_p.add_argument('--last-n', type=int, default=128)
    decode_p.add_argument('--raw-eeg', action='store_true')
    decode_p.add_argument('--input-dim', type=int, default=None)
    decode_p.add_argument('--save-vectors', action='store_true')
    decode_p.add_argument('--vector-output', default='semantic_vectors.npz')

    args = parser.parse_args()

    # No subcommand: show usage and exit successfully.
    if args.command is None:
        parser.print_help()
        sys.exit(0)

    # index create/rebuild needs a corpus directory to read from.
    if args.command == 'index' and args.action in ('create', 'rebuild') and not args.corpus:
        print("--corpus is required for create/rebuild")
        sys.exit(1)

    dispatch = {'record': cmd_record, 'index': cmd_index, 'decode': cmd_decode}
    dispatch[args.command](args)
|
|
|
|
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == '__main__':
    main()
|
|