#!/usr/bin/env python3
"""
morphism — EEG-to-text semantic search
Usage:
morphism record [options]
morphism index create|info|rebuild [options]
morphism decode [options]
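
Examples (file names are placeholders):
    morphism record --port /dev/ttyUSB0 --sample-rate 1000
    morphism index create --corpus ~/corpus --db morphism.db
    morphism decode -f session.txt -a autoencoder.pt -s semantic.pt --mode flood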
"""
import sys
import os
import argparse
from retrieval import FloodMode, DriftMode, FocusMode, LayeredMode
def cmd_record(args):
"""Record EEG data from OpenBCI Cyton+Daisy"""
    from cyton import (
        init_board, set_sample_rate, read_complete_packet,
        process_packet, create_ssh_connection, sd_record,
    )
import serial, time, io
from datetime import datetime
if args.sd:
sd_record(args.port, args.duration, args.sample_rate)
return
filename = args.output
if filename is None:
filename = f"openbci_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
ser = serial.Serial(args.port, 115200)
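    # Give the board a moment to settle after the port opens, before sending commands.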
time.sleep(2)
init_board(ser)
if args.sample_rate != 1000:
set_sample_rate(ser, args.sample_rate)
ssh, sftp, remote_file = None, None, None
if args.remote:
ssh = create_ssh_connection()
if not ssh:
print("SSH connection failed.")
return
sftp = ssh.open_sftp()
remote_file = sftp.open(filename, 'w')
header = "Timestamp," + ",".join(f"Channel{i+1}" for i in range(16)) + "\n"
if args.remote:
remote_file.write(header)
else:
with open(filename, 'w') as f:
f.write(header)
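    # 'b' starts streaming on the OpenBCI board; 's' (sent on exit) stops it.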
ser.write(b'b')
time.sleep(0.5)
ser.reset_input_buffer()
print(f"Recording to {filename} — Ctrl+C to stop")
pkt_count = 0
t0 = time.time()
buf = io.StringIO()
last_flush = time.time()
try:
while True:
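            # Each 16-channel sample arrives as two consecutive 8-channel packets
            # (the lower 8 channels, then the Daisy's upper 8).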
p1 = read_complete_packet(ser)
if not p1:
continue
p2 = read_complete_packet(ser)
if not p2:
continue
d1, d2 = process_packet(p1), process_packet(p2)
if not (d1 and d2):
continue
pkt_count += 1
ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
line = ts + "," + ",".join(f"{x:.6f}" for x in d1 + d2) + "\n"
if args.remote:
buf.write(line)
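                # Batch remote writes and flush roughly every 100 ms so SFTP
                # round-trips do not stall the acquisition loop.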
if time.time() - last_flush >= 0.1:
remote_file.write(buf.getvalue())
buf = io.StringIO()
last_flush = time.time()
else:
with open(filename, 'a') as f:
f.write(line)
if pkt_count % 125 == 0:
rate = pkt_count / (time.time() - t0)
print(f"\r {rate:.1f} Hz, {pkt_count} packets", end='')
if ser.in_waiting > 1000:
ser.reset_input_buffer()
except KeyboardInterrupt:
ser.write(b's')
ser.close()
if args.remote:
if buf.getvalue():
remote_file.write(buf.getvalue())
remote_file.close()
sftp.close()
ssh.close()
elapsed = time.time() - t0
print(f"\n\nDone — {pkt_count} packets in {elapsed:.1f}s ({pkt_count/elapsed:.1f} Hz)")
print(f"Saved to {filename}")
def cmd_index(args):
"""Manage the text embedding index"""
from embed import (
get_splitter, process_batch, create_index_if_possible,
get_existing_content, INITIAL_BATCH_SIZE, MIN_BATCH_SIZE, SHUFFLE_SEED
)
import sqlite3, numpy as np, random
from tqdm import tqdm
db_path = os.path.expanduser(args.db)
index_prefix = args.index
if args.action == 'info':
if not os.path.exists(db_path):
print(f"No database at {db_path}")
return
conn = sqlite3.connect(db_path)
c = conn.cursor()
c.execute("SELECT COUNT(*) FROM messages")
msg_count = c.fetchone()[0]
c.execute("SELECT COUNT(*) FROM embeddings")
emb_count = c.fetchone()[0]
conn.close()
index_exists = os.path.exists(f"{index_prefix}.index")
print(f"Database: {db_path}")
print(f"Messages: {msg_count:,}")
print(f"Embeddings: {emb_count:,}")
print(f"FAISS index: {'exists' if index_exists else 'not built'} ({index_prefix}.index)")
return
if args.action in ('create', 'rebuild'):
corpus = os.path.expanduser(args.corpus)
if not os.path.isdir(corpus):
print(f"Not a directory: {corpus}")
sys.exit(1)
splitter = get_splitter(args.split_mode, args.chunk_size, args.chunk_overlap)
print(f"Loading model: {args.model}")
from transformers import AutoModel
model = AutoModel.from_pretrained(args.model, trust_remote_code=True).cuda()
model.eval()
        conn = sqlite3.connect(db_path)
        c = conn.cursor()
        # Ensure the schema exists before a rebuild tries to DELETE from it.
        c.execute("""CREATE TABLE IF NOT EXISTS messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT, content TEXT, role TEXT)""")
        c.execute("""CREATE TABLE IF NOT EXISTS embeddings (
            message_id INTEGER PRIMARY KEY, embedding BLOB,
            FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE)""")
        conn.commit()
        if args.action == 'rebuild':
            print("Dropping existing data...")
            c.execute("DELETE FROM embeddings")
            c.execute("DELETE FROM messages")
            conn.commit()
        create_index_if_possible(c)
        conn.commit()
existing = get_existing_content(c)
print(f"Already indexed: {len(existing):,}")
txt_files = [f for f in os.listdir(corpus) if f.lower().endswith('.txt')]
if not txt_files:
print(f"No .txt files in {corpus}")
conn.close()
return
units = []
for fn in txt_files:
with open(os.path.join(corpus, fn), 'r', encoding='utf-8', errors='ignore') as f:
units.extend(splitter(f.read()))
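        # Fixed-seed shuffle keeps the unit ordering stable across resumed runs.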
random.seed(SHUFFLE_SEED)
random.shuffle(units)
new_units = [u for u in units if u not in existing]
print(f"New units to embed: {len(new_units):,}")
if not new_units:
print("Nothing new.")
conn.close()
return
batch_size = args.batch_size
idx = 0
processed = 0
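        # Adaptive batching: halve the batch on failure (typically CUDA OOM)
        # down to MIN_BATCH_SIZE, then skip offending units one at a time.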
with tqdm(total=len(new_units), desc="Embedding") as pbar:
while idx < len(new_units):
batch = new_units[idx:idx + batch_size]
ok = process_batch(model, batch, c, args.task)
if ok:
conn.commit()
pbar.update(len(batch))
processed += len(batch)
idx += len(batch)
else:
if batch_size > MIN_BATCH_SIZE:
batch_size = max(batch_size // 2, MIN_BATCH_SIZE)
print(f"\nOOM — batch size → {batch_size}")
else:
idx += 1
pbar.update(1)
processed += 1
conn.close()
print(f"Embedded {processed:,} units.")
print("Building FAISS index...")
_build_faiss_index(db_path, index_prefix)
print("Done.")
def _build_faiss_index(db_path, index_prefix):
"""Build FAISS index from the embeddings database"""
import sqlite3, numpy as np
from decode import EmbeddingIndex
conn = sqlite3.connect(db_path)
c = conn.cursor()
c.execute("SELECT message_id, embedding FROM embeddings ORDER BY message_id")
embeddings, ids = [], []
for mid, blob in c.fetchall():
embeddings.append(np.frombuffer(blob, dtype=np.float32))
ids.append(mid)
conn.close()
if not embeddings:
print(" No embeddings found.")
return
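    # Stack into an (N, dim) matrix; EmbeddingIndex keeps ids alongside the
    # vectors so search results map back to message rows in SQLite.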
embeddings = np.vstack(embeddings)
print(f" {len(embeddings):,} vectors, dim={embeddings.shape[1]}")
idx = EmbeddingIndex(dim=embeddings.shape[1])
idx.add_embeddings(embeddings, ids)
idx.save(index_prefix)
    # Count and max id are already in hand from the SELECT above;
    # no need to reopen the database.
    np.savez(f"{index_prefix}_metadata.npz", count=len(ids), max_message_id=max(ids))
def cmd_decode(args):
    """Run EEG → text decoding"""
    import logging
    import time
    from decode import EEGSemanticProcessor
processor = EEGSemanticProcessor(
autoencoder_model_path=args.autoencoder,
semantic_model_path=args.semantic,
nexus_db_path=args.db,
embeddings_db_path=args.db,
index_path=args.index,
eeg_file_path=args.eeg,
window_size=args.window_size,
stride=args.stride,
batch_size=args.batch_size,
device=args.device,
search_k=args.search_k,
final_k=args.final_k,
use_raw_eeg=args.raw_eeg,
input_dim_override=args.input_dim,
save_vectors=args.save_vectors,
vector_output_path=args.vector_output,
last_n_messages=args.last_n,
)
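    # Only flood honours the CLI search/final-k flags; drift and focus use
    # fixed search widths (64 and 48).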
modes = {
'flood': lambda: FloodMode(processor.embedding_index, processor.nexus_conn,
search_k=args.search_k, final_k=args.final_k,
last_n=args.last_n),
'drift': lambda: DriftMode(processor.embedding_index, processor.nexus_conn,
search_k=64),
'focus': lambda: FocusMode(processor.embedding_index, processor.nexus_conn,
search_k=48),
'layered': lambda: LayeredMode(processor.embedding_index, processor.nexus_conn),
}
mode = modes[args.mode]()
processor.eeg_stream.start()
try:
consecutive_errors = 0
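        # Two levels of recovery: per-embedding errors back off 10 ms, stream-level
        # errors back off 1 s; five consecutive failures at either level abort.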
while True:
try:
for embedding_data in processor.eeg_stream.get_embeddings(timeout=0.5):
try:
semantic_embedding = processor.process_eeg_embedding(
embedding_data['embedding'])
if processor.save_vectors:
embedding_np = semantic_embedding.detach().cpu().numpy()
processor.vectors_list.append(embedding_np)
processor.timestamps.append({
'start': embedding_data['start_timestamp'],
'end': embedding_data['end_timestamp']
})
if len(processor.vectors_list) % 100 == 0:
logging.getLogger("EEGSemanticStream").info(
f"Collected {len(processor.vectors_list)} vectors")
continue
lines = mode.step(semantic_embedding)
if lines:
output = "\n".join(lines)
print(output)
if processor.log_file:
processor.log_file.write(output + "\n")
processor.log_file.flush()
consecutive_errors = 0
                    except Exception as e:
                        print(f"Error: {e}", file=sys.stderr)
                        consecutive_errors += 1
                        if consecutive_errors >= 5:
                            raise RuntimeError("Too many consecutive errors")
                        time.sleep(0.01)
            except Exception as e:
                # Propagate the inner "too many consecutive errors" signal
                # unchanged; any other stream-level error is retried after 1 s.
                if "Too many" in str(e):
                    raise
                print(f"Error: {e}", file=sys.stderr)
                consecutive_errors += 1
                if consecutive_errors >= 5:
                    raise
                time.sleep(1)
except KeyboardInterrupt:
pass
    except Exception as e:
        print(f"Fatal: {e}", file=sys.stderr)
finally:
if processor.save_vectors and processor.vectors_list:
processor.save_vectors_to_disk()
processor.eeg_stream.stop()
def main():
p = argparse.ArgumentParser(
prog='morphism',
description='EEG-to-text semantic search',
)
sub = p.add_subparsers(dest='command')
# --- record ---
rec = sub.add_parser('record', help='Record EEG from OpenBCI Cyton+Daisy')
rec.add_argument('--port', '-p', default='/dev/ttyUSB0')
rec.add_argument('--output', '-o', default=None)
rec.add_argument('--sample-rate', type=int, default=1000)
rec.add_argument('--sd', action='store_true', help='Record to SD card')
    rec.add_argument('--duration', default='G',
                     help="SD duration code (OpenBCI letter command; 'G' is nominally 1 hour)")
rec.add_argument('--remote', action='store_true', help='Stream via SSH')
# --- index ---
idx = sub.add_parser('index', help='Manage the text embedding index')
idx.add_argument('action', choices=['create', 'info', 'rebuild'])
idx.add_argument('--corpus', '-c', default=None)
idx.add_argument('--db', default='morphism.db')
idx.add_argument('--index', default='morphism')
idx.add_argument('--split-mode', default='line',
choices=['line', 'block', 'sentence', 'chunk'])
idx.add_argument('--chunk-size', type=int, default=512)
idx.add_argument('--chunk-overlap', type=int, default=64)
idx.add_argument('--batch-size', type=int, default=128)
idx.add_argument('--task', default='text-matching')
idx.add_argument('--model', default='jinaai/jina-embeddings-v3')
# --- decode ---
dec = sub.add_parser('decode', help='Run EEG → text decoding')
dec.add_argument('--mode', default='flood', choices=['flood', 'drift', 'focus', 'layered'])
dec.add_argument('--eeg', '-f', required=True)
dec.add_argument('--autoencoder', '-a', required=True)
dec.add_argument('--semantic', '-s', required=True)
dec.add_argument('--db', default='morphism.db')
dec.add_argument('--index', default='morphism')
dec.add_argument('--window-size', type=int, default=624)
dec.add_argument('--stride', type=int, default=32)
dec.add_argument('--batch-size', type=int, default=32)
dec.add_argument('--device', default=None)
dec.add_argument('--search-k', type=int, default=1024)
dec.add_argument('--final-k', type=int, default=1024)
dec.add_argument('--last-n', type=int, default=128)
dec.add_argument('--raw-eeg', action='store_true')
dec.add_argument('--input-dim', type=int, default=None)
dec.add_argument('--save-vectors', action='store_true')
dec.add_argument('--vector-output', default='semantic_vectors.npz')
args = p.parse_args()
if args.command is None:
p.print_help()
sys.exit(0)
if args.command == 'record':
cmd_record(args)
elif args.command == 'index':
if args.action in ('create', 'rebuild') and not args.corpus:
print("--corpus is required for create/rebuild")
sys.exit(1)
cmd_index(args)
elif args.command == 'decode':
cmd_decode(args)
if __name__ == '__main__':
main()