| import os |
| import time |
| import csv |
| import numpy as np |
| import torch |
| import threading |
| import queue |
| from datetime import datetime |
| from pathlib import Path |
| from typing import List, Dict, Tuple, Optional, Callable, Generator, Union, Any |
| import logging |
|
|
# Module-wide logging configuration: timestamped INFO-level records.
# Every class in this file logs through the shared "EEGStream" logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("EEGStream")
|
|
|
|
class EncoderExtractor:
    """Wraps a TorchScript-traced encoder and exposes embedding extraction."""

    def __init__(self, model_path, device=None, force_sequence_length=None):
        """
        Load a traced encoder model.

        Args:
            model_path: Path to the TorchScript (torch.jit) model file.
            device: Torch device to run on; auto-selects CUDA when available.
            force_sequence_length: If set, inputs are resampled to this
                length before encoding (to match the training setup).
        """
        self.device = device if device is not None else torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.force_sequence_length = force_sequence_length

        logger.info(f"Loading traced encoder from {model_path} to {self.device}")
        self.model = torch.jit.load(model_path, map_location=self.device)
        self.model.eval()

        # Probe the model once with a dummy batch to discover the embedding
        # width (assumes a [batch, 16-channel, time] input — TODO confirm).
        probe_len = force_sequence_length or 624
        probe = torch.randn(1, 16, probe_len, device=self.device)
        with torch.no_grad():
            self._embedding_size = self.model(probe).shape[1]
        logger.info(f"Embedding size: {self._embedding_size}")

    def get_embedding_size(self):
        """Return the width of the embedding vectors the encoder produces."""
        return self._embedding_size

    def embed(self, data):
        """Encode a batch shaped [B, C, T]; resamples T when a fixed length is enforced."""
        needs_resample = (
            self.force_sequence_length
            and data.shape[2] != self.force_sequence_length
        )
        if needs_resample:
            data = torch.nn.functional.interpolate(
                data,
                size=self.force_sequence_length,
                mode='linear',
                align_corners=False,
            )
        with torch.no_grad():
            return self.model(data.to(self.device))
|
|
class EEGFileWatcher:
    """
    Watches a CSV file for new data and yields new lines as they appear.

    The file is tailed in a background thread; complete lines are pushed
    onto an internal queue and retrieved with get_new_lines().
    """
    def __init__(self, file_path: str, poll_interval: float = 0.1):
        """
        Initialize the file watcher.

        Args:
            file_path: Path to the CSV file to watch
            poll_interval: How often to check for new data (in seconds)
        """
        self.file_path = Path(file_path)
        self.poll_interval = poll_interval
        # Byte offset of the first not-yet-consumed byte of the file.
        self.last_position = 0
        self.running = False
        self.thread = None
        self.queue = queue.Queue()
        # First line of the file; also pushed onto the queue for consumers.
        self.header = None

    def start(self):
        """Start watching the file in a background thread."""
        if self.running:
            return

        self.running = True
        self.thread = threading.Thread(target=self._watch_file, daemon=True)
        self.thread.start()

    def stop(self):
        """Stop watching the file."""
        self.running = False
        if self.thread:
            self.thread.join(timeout=1.0)

    def _watch_file(self):
        """Background thread that tails the file for appended lines.

        The file is read in binary mode and positions are tracked as byte
        offsets. The previous implementation used text mode and subtracted a
        *character* count from f.tell(); that breaks on CRLF line endings
        (universal-newline translation shrinks the char count) and on
        multi-byte UTF-8, and text-mode seek() to computed offsets is
        undefined behavior per the io docs.
        """
        # Wait until the file exists (or we are told to stop).
        while self.running and not self.file_path.exists():
            logger.info(f"Waiting for file {self.file_path} to exist...")
            time.sleep(self.poll_interval)

        if not self.running:
            return

        logger.info(f"File {self.file_path} found, starting to watch")

        self.last_position = 0

        # Read the header first so consumers see it before any data rows.
        try:
            with open(self.file_path, 'rb') as f:
                raw_header = f.readline()
                self.last_position = f.tell()
            self.header = raw_header.decode('utf-8', errors='replace').strip()
            self.queue.put(self.header)
        except Exception as e:
            logger.error(f"Error reading header: {e}")

        while self.running:
            try:
                # stat() is cheap; only reopen the file once it has grown.
                current_size = self.file_path.stat().st_size
                if current_size > self.last_position:
                    with open(self.file_path, 'rb') as f:
                        f.seek(self.last_position)
                        new_data = f.read()
                        self.last_position = f.tell()

                    lines = new_data.split(b'\n')
                    if not new_data.endswith(b'\n'):
                        # Last chunk is a partial line: rewind past it (exact
                        # byte arithmetic in binary mode) so it is re-read
                        # once the writer completes it.
                        self.last_position -= len(lines[-1])
                        lines = lines[:-1]

                    for raw_line in lines:
                        # Drop a trailing '\r' so CRLF files behave like LF.
                        line = raw_line.decode('utf-8', errors='replace').rstrip('\r')
                        if line.strip():
                            self.queue.put(line)
            except Exception as e:
                logger.error(f"Error watching file: {e}")

            time.sleep(self.poll_interval)

    def get_new_lines(self, timeout: Optional[float] = None) -> List[str]:
        """
        Get any new lines that have been read since the last call.

        Args:
            timeout: How long to wait for new data (in seconds). None means don't wait.

        Returns:
            List of new lines (might be empty if no new data)
        """
        lines = []
        try:
            # BUG FIX: Queue.get(timeout=None) blocks FOREVER, contradicting
            # the documented "None means don't wait" contract — use a
            # non-blocking get in that case.
            if timeout is None:
                line = self.queue.get_nowait()
            else:
                line = self.queue.get(timeout=timeout)
            lines.append(line)

            # Drain whatever else is already queued, without waiting.
            while True:
                try:
                    lines.append(self.queue.get_nowait())
                except queue.Empty:
                    break
        except queue.Empty:
            pass

        return lines
|
|
|
|
class SlidingWindowProcessor:
    """
    Buffers incoming data points and emits fixed-size, optionally
    normalized windows using a sliding-window scheme.
    """
    def __init__(
        self,
        window_size: int,
        stride: int,
        num_channels: int,
        channel_means: List[float],
        channel_stds: List[float],
        normalize: bool = True
    ):
        """
        Initialize the sliding window processor.

        Args:
            window_size: Number of data points in each window
            stride: Number of data points to advance between windows
            num_channels: Number of data channels
            channel_means: Mean value for each channel (for normalization)
            channel_stds: Standard deviation for each channel (for normalization)
            normalize: Whether to normalize the data
        """
        self.window_size = window_size
        self.stride = stride
        self.num_channels = num_channels
        self.channel_means = np.array(channel_means)
        self.channel_stds = np.array(channel_stds)
        self.normalize = normalize

        # Data points (dicts with 'timestamp' + 'ChannelN' keys) not yet
        # fully consumed by windows.
        self.buffer = []

        # Index into self.buffer where the next window starts.
        self.current_pos = 0

    def add_data(self, data_points: List[Dict[str, Union[str, float]]]):
        """
        Add new data points to the buffer.

        Args:
            data_points: List of data points. Each point should be a dictionary with
                         'timestamp' and channel values.
        """
        self.buffer.extend(data_points)

    def get_windows(self) -> Generator[Tuple[List[str], np.ndarray], None, None]:
        """
        Generate windows from the buffered data using the sliding window approach.

        Yields:
            Tuple of (timestamps, data array) for each window
            data array shape: [num_channels, window_size]
        """
        while self.current_pos + self.window_size <= len(self.buffer):
            window = self.buffer[self.current_pos:self.current_pos + self.window_size]

            timestamps = [point['timestamp'] for point in window]

            # Build [channels, time]; missing channel keys stay 0.0.
            data = np.zeros((self.num_channels, self.window_size), dtype=np.float32)
            for i, point in enumerate(window):
                for c in range(self.num_channels):
                    channel_key = f'Channel{c+1}'
                    if channel_key in point:
                        data[c, i] = point[channel_key]

            # Per-channel z-score; channels with non-positive std are left raw.
            if self.normalize:
                for c in range(self.num_channels):
                    if self.channel_stds[c] > 0:
                        data[c] = (data[c] - self.channel_means[c]) / self.channel_stds[c]

            yield timestamps, data

            self.current_pos += self.stride

        # Drop points that can never appear in a future window: everything
        # before current_pos.  BUG FIX: the old formula
        # (current_pos - (window_size - stride)) retained stale overlap
        # points, and when stride > window_size it trimmed PAST current_pos
        # and reset it inconsistently, corrupting window alignment.
        if self.current_pos > 0:
            drop = min(self.current_pos, len(self.buffer))
            self.buffer = self.buffer[drop:]
            # Carry any remaining offset (nonzero only when stride skips
            # past the current end of the buffer).
            self.current_pos -= drop
|
|
|
|
class EEGEmbeddingStream:
    """
    Stream of EEG embeddings from a live CSV file.

    Glues together an EEGFileWatcher (tails the CSV), a
    SlidingWindowProcessor (windowing + per-channel normalization) and an
    EncoderExtractor (TorchScript encoder) so that rows appended to the
    file become embedding vectors.
    """
    def __init__(
        self,
        file_path: str,
        model_path: str,
        window_size: int = 256,
        stride: int = 64,
        normalizer_params: Dict[str, List[float]] = None,
        poll_interval: float = 0.1,
        batch_size: int = 32,
        normalize: bool = True,
        device: str = None,
        start_from_timestamp: str = None,
        force_sequence_length: int = None
    ):
        """
        Initialize the EEG embedding stream.

        Args:
            file_path: Path to the CSV file to watch
            model_path: Path to the trained model checkpoint
            window_size: Number of data points in each window
            stride: Number of data points to advance between windows
            normalizer_params: Dictionary with 'means' and 'stds' for each channel
                If None, default values will be used
            poll_interval: How often to check for new data (in seconds)
            batch_size: How many windows to encode at once
            normalize: Whether to normalize the data
            device: Device to use for encoding ('cuda' or 'cpu')
            start_from_timestamp: Only process data from this timestamp onwards
            force_sequence_length: Force the model to use this sequence length (to match training)
        """
        self.file_path = file_path
        self.poll_interval = poll_interval
        self.window_size = window_size
        self.stride = stride
        self.normalize = normalize
        self.batch_size = batch_size
        self.start_from_timestamp = start_from_timestamp

        # Fallback per-channel normalization constants for a 16-channel
        # setup (presumably measured on the training data — TODO confirm
        # provenance before reusing with different hardware).
        if normalizer_params is None:
            self.channel_means = [-70446.6562, -51197.2070, -42351.2812, -32628.9004, -58139.0547,
                                  -56271.2852, -48508.2305, -57654.8711, -69949.6484, -49663.8398,
                                  -43010.7070, -30252.7207, -56295.6250, -56075.9375, -48470.3086,
                                  -56338.5820]
            self.channel_stds = [76037.4453, 56048.1445, 71950.6328, 60051.6523, 64877.7422,
                                 59371.3203, 56742.6055, 62344.4805, 75861.9141, 55614.6055,
                                 70795.6719, 59312.4180, 64780.2109, 60292.6992, 56598.4609,
                                 61472.3633]
        else:
            self.channel_means = normalizer_params['means']
            self.channel_stds = normalizer_params['stds']

        # Channel count is inferred from the normalization vectors.
        self.num_channels = len(self.channel_means)

        # Pipeline components: file tailer and windower.
        self.file_watcher = EEGFileWatcher(file_path, poll_interval)
        self.window_processor = SlidingWindowProcessor(
            window_size, stride, self.num_channels,
            self.channel_means, self.channel_stds, normalize
        )

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        self.encoder = EncoderExtractor(model_path, self.device, force_sequence_length)

        # CSV header columns; populated from the first line that reaches
        # _parse_csv_line (the watcher forwards the header as a line).
        self.header = None

        # True while the stream is active; flipped by start()/stop().
        self.running = False

        # Simple throughput/progress counters.
        self.stats = {
            'windows_processed': 0,
            'start_time': None,
            'last_timestamp': None
        }

    def start(self):
        """Start the embedding stream."""
        if self.running:
            return

        self.running = True
        self.stats['start_time'] = time.time()
        self.file_watcher.start()

    def stop(self):
        """Stop the embedding stream and log a final throughput summary."""
        self.running = False
        self.file_watcher.stop()

        if self.stats['start_time'] is not None:
            elapsed = time.time() - self.stats['start_time']
            windows_processed = self.stats['windows_processed']
            if windows_processed > 0 and elapsed > 0:
                rate = windows_processed / elapsed
                logger.info(f"Processed {windows_processed} windows in {elapsed:.2f}s ({rate:.2f} windows/s)")

    def _parse_csv_line(self, line: str) -> Optional[Dict[str, Union[str, float]]]:
        """
        Parse a CSV line into a data point.

        Args:
            line: CSV line

        Returns:
            Dictionary with timestamp and channel values, or None if header
        """
        # The first line ever seen is treated as the header row.
        if not self.header:
            self.header = line.split(',')
            logger.info(f"CSV header: {self.header}")
            return None

        # NOTE(review): naive comma split — assumes no quoted/escaped
        # commas in the data (the csv module would be needed otherwise).
        values = line.split(',')
        if len(values) != len(self.header):
            logger.warning(f"Line has wrong number of values: {line}")
            return None

        data_point = {}
        for i, column in enumerate(self.header):
            if i == 0:
                # First column is the timestamp (kept as a string).
                data_point['timestamp'] = values[i]

                # NOTE(review): lexicographic string comparison — only
                # correct if timestamps are fixed-width and sortable
                # (e.g. ISO-8601); confirm the writer's timestamp format.
                if self.start_from_timestamp and values[i] < self.start_from_timestamp:
                    return None
            else:
                # Remaining columns are channel values; unparseable cells
                # are coerced to 0.0 rather than dropping the whole row.
                try:
                    data_point[column] = float(values[i])
                except ValueError:
                    logger.warning(f"Could not parse value {values[i]} as float for column {column}")
                    data_point[column] = 0.0

        return data_point

    def get_embeddings(self, timeout: Optional[float] = None) -> Generator[Dict[str, Any], None, None]:
        """
        Get embeddings for new data.

        Args:
            timeout: How long to wait for new data (in seconds). None means don't wait.

        Yields:
            Dictionary with window information and embedding
        """
        if not self.running:
            self.start()

        # Pull whatever the watcher has queued since the last call.
        new_lines = self.file_watcher.get_new_lines(timeout)
        if not new_lines:
            return

        # Parse lines into data points (the header line and malformed
        # lines come back as None and are skipped).
        data_points = []
        for line in new_lines:
            data_point = self._parse_csv_line(line)
            if data_point:
                data_points.append(data_point)
                self.stats['last_timestamp'] = data_point['timestamp']

        if not data_points:
            return

        self.window_processor.add_data(data_points)

        # Materialize every window currently available from the buffer.
        windows = list(self.window_processor.get_windows())
        if not windows:
            return

        # Encode in batches to bound memory and exploit device batching.
        for batch_start in range(0, len(windows), self.batch_size):
            batch_end = min(batch_start + self.batch_size, len(windows))
            batch = windows[batch_start:batch_end]

            batch_timestamps = [window[0] for window in batch]
            batch_data = [window[1] for window in batch]

            # Stack to a [batch, channels, window_size] tensor.
            batch_tensor = torch.tensor(np.array(batch_data), dtype=torch.float32)

            embeddings = self.encoder.embed(batch_tensor)

            # Move results off the device so callers get plain numpy arrays.
            embeddings_np = embeddings.cpu().numpy()

            for i in range(len(batch)):
                self.stats['windows_processed'] += 1
                yield {
                    'start_timestamp': batch_timestamps[i][0],
                    'end_timestamp': batch_timestamps[i][-1],
                    'embedding': embeddings_np[i],
                    'window_index': self.stats['windows_processed'] - 1
                }

    def get_streaming_embeddings(self, callback: Optional[Callable[[Dict[str, Any]], None]] = None) -> Generator[Dict[str, Any], None, None]:
        """
        Continuously generate embeddings and call the callback function with each one.

        NOTE(review): this is a generator — it does nothing until iterated,
        even when a callback is supplied. Callers must drive it with a loop.

        Args:
            callback: Function to call with each embedding. If None, embeddings are yielded.

        Yields:
            If no callback is provided, yields dictionaries with window information and embedding
        """
        self.start()

        try:
            while self.running:
                any_embeddings = False
                for embedding in self.get_embeddings(timeout=self.poll_interval):
                    any_embeddings = True
                    if callback:
                        callback(embedding)
                    else:
                        yield embedding

                if not any_embeddings:
                    # Nothing new yet; back off before polling again.
                    time.sleep(self.poll_interval)
        finally:
            self.stop()

    def get_stats(self) -> Dict[str, Any]:
        """
        Get statistics about the streaming process.

        Returns:
            Dictionary with statistics
        """
        stats = dict(self.stats)
        if stats['start_time'] is not None:
            stats['elapsed'] = time.time() - stats['start_time']
            if stats['windows_processed'] > 0 and stats['elapsed'] > 0:
                stats['windows_per_second'] = stats['windows_processed'] / stats['elapsed']
        return stats
|
|
| |
def example():
    """Demo: stream embeddings from a live CSV file and print each one."""
    def handle_embedding(embedding):
        """Callback function to handle new embeddings."""
        start_time = embedding['start_timestamp']
        end_time = embedding['end_timestamp']
        embedding_data = embedding['embedding']

        print(f"Got embedding for window from {start_time} to {end_time}")
        print(f"Embedding shape: {embedding_data.shape}")
        print(f"First few values: {embedding_data.flatten()[:5]}")

    stream = EEGEmbeddingStream(
        file_path="eeg_data.csv",
        model_path="models/eeg_autoencoder.pth",
        window_size=256,
        stride=128,
        poll_interval=0.5
    )

    print("Starting embedding stream...")
    print("Press Ctrl+C to stop")

    try:
        # BUG FIX: get_streaming_embeddings() is a generator — merely
        # calling it ran nothing, so the demo silently did no work.
        # With a callback supplied the generator yields no items, so an
        # empty for-loop is what actually drives the stream until stopped.
        for _ in stream.get_streaming_embeddings(callback=handle_embedding):
            pass
    except KeyboardInterrupt:
        print("\nStopping...")
    finally:
        stream.stop()
        print("Stopped")
|
|
|
|
# Run the demo when executed as a script (no effect on import).
if __name__ == "__main__":
    example()
|
|