import os
import time
import csv
import numpy as np
import torch
import threading
import queue
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Callable, Generator, Union, Any
import logging

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("EEGStream")


class EncoderExtractor:
    """Wraps a TorchScript-traced encoder model as a batch embedding function.

    The traced model is probed once at load time with a dummy batch to
    discover the width of the embeddings it produces.
    """

    def __init__(self, model_path: str, device: Optional[torch.device] = None,
                 force_sequence_length: Optional[int] = None):
        """
        Args:
            model_path: Path to the TorchScript (torch.jit) model file.
            device: Device to run the model on. Auto-selects CUDA when
                available if None.
            force_sequence_length: If set, inputs are resampled to this length
                before encoding (to match the sequence length used in training).
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
        self.force_sequence_length = force_sequence_length

        logger.info(f"Loading traced encoder from {model_path} to {self.device}")
        self.model = torch.jit.load(model_path, map_location=self.device)
        self.model.eval()

        # Probe the model once to learn the embedding width.
        # NOTE(review): the dummy assumes 16 input channels and a default
        # sequence length of 624 — confirm these match the traced model.
        dummy = torch.randn(1, 16, force_sequence_length or 624, device=self.device)
        with torch.no_grad():
            self._embedding_size = self.model(dummy).shape[1]
        logger.info(f"Embedding size: {self._embedding_size}")

    def get_embedding_size(self) -> int:
        """Return the dimensionality of the embeddings the encoder produces."""
        return self._embedding_size

    def embed(self, data: torch.Tensor) -> torch.Tensor:
        """Encode a batch of windows.

        Args:
            data: Tensor of shape [batch, channels, sequence_length].

        Returns:
            Tensor of shape [batch, embedding_size] on self.device.
        """
        if self.force_sequence_length and data.shape[2] != self.force_sequence_length:
            # Resample along time so the model sees its training-time length.
            data = torch.nn.functional.interpolate(
                data, size=self.force_sequence_length, mode='linear', align_corners=False
            )
        with torch.no_grad():
            return self.model(data.to(self.device))


class EEGFileWatcher:
    """
    Watches a CSV file for new data and yields new lines as they appear.
    """

    def __init__(self, file_path: str, poll_interval: float = 0.1):
        """
        Initialize the file watcher.

        Args:
            file_path: Path to the CSV file to watch
            poll_interval: How often to check for new data (in seconds)
        """
        self.file_path = Path(file_path)
        self.poll_interval = poll_interval
        self.last_position = 0  # byte offset of the next unread byte
        self.running = False
        self.thread: Optional[threading.Thread] = None
        self.queue: "queue.Queue[str]" = queue.Queue()
        self.header: Optional[str] = None

    def start(self):
        """Start watching the file in a background thread."""
        if self.running:
            return
        self.running = True
        self.thread = threading.Thread(target=self._watch_file, daemon=True)
        self.thread.start()

    def stop(self):
        """Stop watching the file."""
        self.running = False
        if self.thread:
            self.thread.join(timeout=1.0)

    def _watch_file(self):
        """Background thread that polls the file for appended data.

        The file is read in BINARY mode throughout. This is deliberate:
        text-mode ``tell()`` values are opaque tokens, so the arithmetic used
        below to rewind over a partially-written last line is only valid on
        raw byte offsets (and with ``\\r\\n`` newline translation the character
        count would not equal the byte count).
        """
        # Wait for the file to exist
        while self.running and not self.file_path.exists():
            logger.info(f"Waiting for file {self.file_path} to exist...")
            time.sleep(self.poll_interval)

        logger.info(f"File {self.file_path} found, starting to watch")

        # Keep track of the file position
        self.last_position = 0

        # Read header first
        try:
            with open(self.file_path, 'rb') as f:
                raw_header = f.readline()
                self.last_position = f.tell()
            self.header = raw_header.decode('utf-8', errors='replace').strip()
            self.queue.put(self.header)
        except Exception as e:
            logger.error(f"Error reading header: {e}")

        while self.running:
            try:
                # Check if the file has grown
                current_size = self.file_path.stat().st_size
                if current_size > self.last_position:
                    # Read new bytes from where we left off
                    with open(self.file_path, 'rb') as f:
                        f.seek(self.last_position)
                        new_bytes = f.read()
                        self.last_position = f.tell()

                    # Process new lines (excluding a trailing partial line)
                    raw_lines = new_bytes.split(b'\n')
                    if not new_bytes.endswith(b'\n'):
                        # The last line may still be mid-write: rewind so it
                        # is re-read in full on the next poll. Valid because
                        # these are raw byte offsets (binary mode).
                        self.last_position -= len(raw_lines[-1])
                        raw_lines = raw_lines[:-1]

                    # Add complete, non-empty lines to the queue
                    for raw in raw_lines:
                        line = raw.decode('utf-8', errors='replace').rstrip('\r')
                        if line.strip():  # Skip empty lines
                            self.queue.put(line)
            except Exception as e:
                logger.error(f"Error watching file: {e}")

            time.sleep(self.poll_interval)

    def get_new_lines(self, timeout: Optional[float] = None) -> List[str]:
        """
        Get any new lines that have been read since the last call.

        Args:
            timeout: How long to wait for new data (in seconds).
                None means don't wait at all.

        Returns:
            List of new lines (might be empty if no new data)
        """
        lines: List[str] = []
        try:
            # Get the first line. A None timeout must NOT block: Queue.get
            # with timeout=None would wait forever, so use get_nowait.
            if timeout is None:
                line = self.queue.get_nowait()
            else:
                line = self.queue.get(timeout=timeout)
            lines.append(line)

            # Drain any remaining lines without waiting
            while True:
                try:
                    lines.append(self.queue.get_nowait())
                except queue.Empty:
                    break
        except queue.Empty:
            pass

        return lines


class SlidingWindowProcessor:
    """
    Processes data using a sliding window approach.
    """

    def __init__(
        self,
        window_size: int,
        stride: int,
        num_channels: int,
        channel_means: List[float],
        channel_stds: List[float],
        normalize: bool = True
    ):
        """
        Initialize the sliding window processor.

        Args:
            window_size: Number of data points in each window
            stride: Number of data points to advance between windows
            num_channels: Number of data channels
            channel_means: Mean value for each channel (for normalization)
            channel_stds: Standard deviation for each channel (for normalization)
            normalize: Whether to normalize the data
        """
        self.window_size = window_size
        self.stride = stride
        self.num_channels = num_channels
        self.channel_means = np.array(channel_means)
        self.channel_stds = np.array(channel_stds)
        self.normalize = normalize

        # Buffer of pending data points (dicts with 'timestamp' + channels)
        self.buffer: List[Dict[str, Union[str, float]]] = []
        # Index into the buffer where the next window starts
        self.current_pos = 0

    def add_data(self, data_points: List[Dict[str, Union[str, float]]]):
        """
        Add new data points to the buffer.

        Args:
            data_points: List of data points. Each point should be a dictionary
                with 'timestamp' and channel values.
        """
        self.buffer.extend(data_points)

    def get_windows(self) -> Generator[Tuple[List[str], np.ndarray], None, None]:
        """
        Generate windows from the buffered data using the sliding window approach.

        Yields:
            Tuple of (timestamps, data array) for each window.
            data array shape: [num_channels, window_size]
        """
        while self.current_pos + self.window_size <= len(self.buffer):
            # Extract window
            window = self.buffer[self.current_pos:self.current_pos + self.window_size]

            # Extract timestamps
            timestamps = [point['timestamp'] for point in window]

            # Extract data; channel columns are named 'Channel1'..'ChannelN'
            data = np.zeros((self.num_channels, self.window_size), dtype=np.float32)
            for i, point in enumerate(window):
                for c in range(self.num_channels):
                    channel_key = f'Channel{c+1}'
                    if channel_key in point:
                        data[c, i] = point[channel_key]

            # Normalize if requested (per-channel z-score; skip std == 0)
            if self.normalize:
                for c in range(self.num_channels):
                    if self.channel_stds[c] > 0:
                        data[c] = (data[c] - self.channel_means[c]) / self.channel_stds[c]

            yield timestamps, data

            # Advance by stride
            self.current_pos += self.stride

        # Drop everything before the next window start. Capping at the buffer
        # length keeps current_pos consistent when stride > window_size (the
        # old formula trimmed past current_pos in that case, losing points).
        if self.current_pos > 0:
            drop = min(self.current_pos, len(self.buffer))
            self.buffer = self.buffer[drop:]
            self.current_pos -= drop


class EEGEmbeddingStream:
    """
    Stream of EEG embeddings from a live CSV file.
    """

    def __init__(
        self,
        file_path: str,
        model_path: str,
        window_size: int = 256,
        stride: int = 64,
        normalizer_params: Dict[str, List[float]] = None,
        poll_interval: float = 0.1,
        batch_size: int = 32,
        normalize: bool = True,
        device: str = None,
        start_from_timestamp: str = None,
        force_sequence_length: int = None
    ):
        """
        Initialize the EEG embedding stream.

        Args:
            file_path: Path to the CSV file to watch
            model_path: Path to the trained model checkpoint
            window_size: Number of data points in each window
            stride: Number of data points to advance between windows
            normalizer_params: Dictionary with 'means' and 'stds' for each channel.
                If None, default values will be used.
            poll_interval: How often to check for new data (in seconds)
            batch_size: How many windows to encode at once
            normalize: Whether to normalize the data
            device: Device to use for encoding ('cuda' or 'cpu')
            start_from_timestamp: Only process data from this timestamp onwards
            force_sequence_length: Force the model to use this sequence length
                (to match training)
        """
        self.file_path = file_path
        self.poll_interval = poll_interval
        self.window_size = window_size
        self.stride = stride
        self.normalize = normalize
        self.batch_size = batch_size
        self.start_from_timestamp = start_from_timestamp

        # Set default normalizer parameters if not provided.
        # NOTE(review): these defaults look like per-channel statistics from a
        # specific 16-channel recording setup — verify against the training data.
        if normalizer_params is None:
            self.channel_means = [-70446.6562, -51197.2070, -42351.2812, -32628.9004,
                                  -58139.0547, -56271.2852, -48508.2305, -57654.8711,
                                  -69949.6484, -49663.8398, -43010.7070, -30252.7207,
                                  -56295.6250, -56075.9375, -48470.3086, -56338.5820]
            self.channel_stds = [76037.4453, 56048.1445, 71950.6328, 60051.6523,
                                 64877.7422, 59371.3203, 56742.6055, 62344.4805,
                                 75861.9141, 55614.6055, 70795.6719, 59312.4180,
                                 64780.2109, 60292.6992, 56598.4609, 61472.3633]
        else:
            self.channel_means = normalizer_params['means']
            self.channel_stds = normalizer_params['stds']

        # Determine the number of channels
        self.num_channels = len(self.channel_means)

        # Initialize components
        self.file_watcher = EEGFileWatcher(file_path, poll_interval)
        self.window_processor = SlidingWindowProcessor(
            window_size, stride, self.num_channels,
            self.channel_means, self.channel_stds, normalize
        )

        # Set the device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        # Load the encoder with the forced sequence length
        self.encoder = EncoderExtractor(model_path, self.device, force_sequence_length)

        # CSV header (parsed lazily from the first streamed line)
        self.header: Optional[List[str]] = None

        # Running flag
        self.running = False

        # Statistics
        self.stats: Dict[str, Any] = {
            'windows_processed': 0,
            'start_time': None,
            'last_timestamp': None
        }

    def start(self):
        """Start the embedding stream."""
        if self.running:
            return
        self.running = True
        self.stats['start_time'] = time.time()
        self.file_watcher.start()

    def stop(self):
        """Stop the embedding stream and log throughput statistics."""
        self.running = False
        self.file_watcher.stop()

        if self.stats['start_time'] is not None:
            elapsed = time.time() - self.stats['start_time']
            windows_processed = self.stats['windows_processed']
            if windows_processed > 0 and elapsed > 0:
                rate = windows_processed / elapsed
                logger.info(f"Processed {windows_processed} windows in {elapsed:.2f}s ({rate:.2f} windows/s)")

    def _parse_csv_line(self, line: str) -> Optional[Dict[str, Union[str, float]]]:
        """
        Parse a CSV line into a data point.

        Args:
            line: CSV line

        Returns:
            Dictionary with timestamp and channel values, or None if the line
            is the header, malformed, or before ``start_from_timestamp``.
        """
        if not self.header:
            # First line is the header
            self.header = line.split(',')
            logger.info(f"CSV header: {self.header}")
            return None

        values = line.split(',')
        if len(values) != len(self.header):
            logger.warning(f"Line has wrong number of values: {line}")
            return None

        data_point: Dict[str, Union[str, float]] = {}
        for i, column in enumerate(self.header):
            if i == 0:
                # Timestamp column. The start_from_timestamp filter compares
                # strings lexicographically — assumes sortable timestamps
                # (e.g. ISO 8601); TODO confirm the writer's format.
                data_point['timestamp'] = values[i]
                if self.start_from_timestamp and values[i] < self.start_from_timestamp:
                    return None
            else:
                # Channel column; unparseable values fall back to 0.0
                try:
                    data_point[column] = float(values[i])
                except ValueError:
                    logger.warning(f"Could not parse value {values[i]} as float for column {column}")
                    data_point[column] = 0.0

        return data_point

    def get_embeddings(self, timeout: Optional[float] = None) -> Generator[Dict[str, Any], None, None]:
        """
        Get embeddings for new data.

        Args:
            timeout: How long to wait for new data (in seconds). None means don't wait.

        Yields:
            Dictionary with window information and embedding:
            'start_timestamp', 'end_timestamp', 'embedding' (np.ndarray),
            'window_index'.
        """
        if not self.running:
            self.start()

        # Get new lines from the file
        new_lines = self.file_watcher.get_new_lines(timeout)
        if not new_lines:
            return

        # Parse CSV lines
        data_points = []
        for line in new_lines:
            data_point = self._parse_csv_line(line)
            if data_point:
                data_points.append(data_point)
                self.stats['last_timestamp'] = data_point['timestamp']

        if not data_points:
            return

        # Add to the window processor
        self.window_processor.add_data(data_points)

        # Get windows and batch them for embedding
        windows = list(self.window_processor.get_windows())
        if not windows:
            return

        for batch_start in range(0, len(windows), self.batch_size):
            batch_end = min(batch_start + self.batch_size, len(windows))
            batch = windows[batch_start:batch_end]

            # Extract timestamps and data
            batch_timestamps = [window[0] for window in batch]
            batch_data = [window[1] for window in batch]

            # Convert to a [batch, channels, window_size] tensor
            batch_tensor = torch.tensor(np.array(batch_data), dtype=torch.float32)

            # Generate embeddings
            embeddings = self.encoder.embed(batch_tensor)

            # Convert to numpy and yield one result per window
            embeddings_np = embeddings.cpu().numpy()
            for i in range(len(batch)):
                self.stats['windows_processed'] += 1
                yield {
                    'start_timestamp': batch_timestamps[i][0],
                    'end_timestamp': batch_timestamps[i][-1],
                    'embedding': embeddings_np[i],
                    'window_index': self.stats['windows_processed'] - 1
                }

    def get_streaming_embeddings(self, callback: Optional[Callable[[Dict[str, Any]], None]] = None) -> Generator[Dict[str, Any], None, None]:
        """
        Continuously generate embeddings and call the callback function with each one.

        This is a generator function: even in callback mode the returned
        generator must be iterated to drive the stream (it will then yield
        nothing, since each embedding goes to the callback instead).

        Args:
            callback: Function to call with each embedding. If None, embeddings
                are yielded.

        Yields:
            If no callback is provided, yields dictionaries with window
            information and embedding.
        """
        self.start()
        try:
            while self.running:
                any_embeddings = False
                for embedding in self.get_embeddings(timeout=self.poll_interval):
                    any_embeddings = True
                    if callback:
                        callback(embedding)
                    else:
                        yield embedding

                if not any_embeddings:
                    # No new embeddings, just wait a bit
                    time.sleep(self.poll_interval)
        finally:
            self.stop()

    def get_stats(self) -> Dict[str, Any]:
        """
        Get statistics about the streaming process.

        Returns:
            Dictionary with statistics (windows processed, elapsed time,
            throughput, last seen timestamp).
        """
        stats = dict(self.stats)
        if stats['start_time'] is not None:
            stats['elapsed'] = time.time() - stats['start_time']
            if stats['windows_processed'] > 0 and stats['elapsed'] > 0:
                stats['windows_per_second'] = stats['windows_processed'] / stats['elapsed']
        return stats


# Example usage
def example():
    def handle_embedding(embedding):
        """Callback function to handle new embeddings."""
        start_time = embedding['start_timestamp']
        end_time = embedding['end_timestamp']
        embedding_data = embedding['embedding']
        print(f"Got embedding for window from {start_time} to {end_time}")
        print(f"Embedding shape: {embedding_data.shape}")
        print(f"First few values: {embedding_data.flatten()[:5]}")

    # Create the embedding stream
    stream = EEGEmbeddingStream(
        file_path="eeg_data.csv",
        model_path="models/eeg_autoencoder.pth",
        window_size=256,   # Number of data points in each window
        stride=128,        # How much to advance between windows
        poll_interval=0.5  # Check for new data every 0.5 seconds
    )

    print("Starting embedding stream...")
    print("Press Ctrl+C to stop")

    try:
        # Method 1: Using callback. get_streaming_embeddings is a generator
        # function, so it must be iterated for anything to happen — merely
        # calling it creates a generator and the callback would never fire.
        for _ in stream.get_streaming_embeddings(callback=handle_embedding):
            pass

        # Method 2: Using generator
        # for embedding in stream.get_streaming_embeddings():
        #     handle_embedding(embedding)
    except KeyboardInterrupt:
        print("\nStopping...")
    finally:
        stream.stop()
        print("Stopped")


if __name__ == "__main__":
    example()