| import os |
| import gc |
| from pathlib import Path |
| from typing import Tuple, List, Optional, Union, Dict, Any |
|
|
| import h5py |
| import numpy as np |
| import pandas as pd |
| import scipy.signal as signal |
| from joblib import Parallel, delayed |
| from scipy.signal import iirnotch |
| from tqdm import tqdm |
|
|
def sequence_to_seconds(seq_len: int, fs: float) -> float:
    """Return how many seconds ``seq_len`` samples span at sampling rate ``fs``.

    Args:
        seq_len (int): Length of the sequence, in samples.
        fs (float): Sampling rate, in Hz (samples per second).

    Returns:
        float: Duration in seconds, i.e. ``seq_len / fs``.
    """
    duration_s = seq_len / fs
    return duration_s
|
|
|
|
def notch_filter(data: np.ndarray, notch_freq: float = 50.0, Q: float = 30.0, fs: float = 2000.0) -> np.ndarray:
    """Apply a zero-phase IIR notch filter to every channel of ``data`` independently.

    The filter is designed with ``scipy.signal.iirnotch`` and applied forward
    and backward with ``scipy.signal.filtfilt`` along the time axis (axis 0),
    so channels are filtered independently and without phase shift.

    Args:
        data (np.ndarray): Signal of shape (T, D), or a 1-D signal of shape (T,).
        notch_freq (float, optional): Frequency to remove, in Hz. Must lie
            strictly between 0 and ``fs / 2``. Defaults to 50.0.
        Q (float, optional): Quality factor; larger values give a narrower
            notch. Defaults to 30.0.
        fs (float, optional): Sampling frequency in Hz. Defaults to 2000.0.

    Returns:
        np.ndarray: Filtered signal with the same shape as ``data``. Float and
        complex inputs keep their dtype; any other dtype (e.g. integers) is
        promoted to float64 so the filtered values are not truncated.

    Raises:
        ValueError: Propagated from SciPy, e.g. when ``notch_freq`` is outside
            (0, fs/2) or the signal is too short for filtfilt's edge padding.
    """
    b, a = iirnotch(notch_freq, Q, fs=fs)
    data = np.asarray(data)
    # Writing into np.zeros_like(int_array) would truncate the filtered
    # values, so only inexact (float/complex) dtypes are preserved.
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    # filtfilt operates along one axis, so a single call filters all channels.
    return signal.filtfilt(b, a, data, axis=0).astype(out_dtype, copy=False)
|
|
|
|
def bandpass_filter_emg(
    emg: np.ndarray,
    lowcut: float = 20.0,
    highcut: float = 90.0,
    fs: float = 2000.0,
    order: int = 4,
) -> np.ndarray:
    """Apply a zero-phase Butterworth band-pass filter to every EMG channel.

    The filter is designed with ``scipy.signal.butter`` (normalized to the
    Nyquist frequency) and applied forward and backward with
    ``scipy.signal.filtfilt`` along the time axis (axis 0).

    Args:
        emg (np.ndarray): Signal of shape (T, D), or a 1-D signal of shape (T,).
        lowcut (float, optional): Lower edge of the passband in Hz. Defaults to 20.0.
        highcut (float, optional): Upper edge of the passband in Hz. Defaults to 90.0.
        fs (float, optional): Sampling frequency in Hz. Defaults to 2000.0.
        order (int, optional): Butterworth design order. Defaults to 4.

    Returns:
        np.ndarray: Filtered signal with the same shape as ``emg``. Float and
        complex inputs keep their dtype; any other dtype (e.g. integers) is
        promoted to float64 so the filtered values are not truncated.

    Raises:
        ValueError: If the band edges do not satisfy
            ``0 < lowcut < highcut < fs / 2``, or (from SciPy) if the signal is
            too short for filtfilt's edge padding.
    """
    nyq = 0.5 * fs
    # A reversed or out-of-range band would otherwise yield a meaningless
    # filter instead of a clear error.
    if not 0.0 < lowcut < highcut < nyq:
        raise ValueError(
            "Band edges must satisfy 0 < lowcut < highcut < fs/2 "
            f"(got lowcut={lowcut}, highcut={highcut}, fs={fs})."
        )
    b, a = signal.butter(order, [lowcut / nyq, highcut / nyq], btype="bandpass")
    emg = np.asarray(emg)
    out_dtype = emg.dtype if np.issubdtype(emg.dtype, np.inexact) else np.float64
    # One vectorized call along time filters every channel independently.
    return signal.filtfilt(b, a, emg, axis=0).astype(out_dtype, copy=False)
|
|
|
|
| def process_emg_features(emg: np.ndarray, window_size: int = 1000, stride: int = 500) -> np.ndarray: |
| """Segments raw EMG signals into overlapping windows. |
| |
| Args: |
| emg (np.ndarray): Raw EMG data of shape (T, n_ch). |
| window_size (int, optional): Number of samples per window. Defaults to 1000. |
| stride (int, optional): Number of samples to shift between windows. Defaults to 500. |
| |
| Returns: |
| np.ndarray: Segmented data of shape (N, window_size, n_ch). |
| """ |
| segs = [] |
| N = len(emg) |
| for start in range(0, N, stride): |
| end = start + window_size |
| if end > N: |
| continue |
| win = emg[start:end] |
| segs.append(win) |
| return np.array(segs) |
|
|
|
|
def process_one_recording(
    file_path: Union[str, os.PathLike],
    fs: float = 2000.0,
    window_size: int = 1000,
    stride: int = 500,
    *,
    lowcut: float = 20.0,
    highcut: float = 450.0,
    notch_freq: float = 50.0,
    notch_q: float = 30.0,
) -> np.ndarray:
    """Load, filter, normalize and window a single emg2pose recording.

    Pipeline:
      1. Read the ``emg`` field of the ``emg2pose/timeseries`` dataset as float32.
      2. Band-pass filter (``lowcut``-``highcut`` Hz), then notch filter at
         ``notch_freq`` Hz with quality factor ``notch_q``.
      3. Z-score each channel over the whole recording (sample std,
         ``ddof=1``); channels whose std is exactly 0 are only mean-centered.
      4. Cut into windows of ``window_size`` samples every ``stride`` samples.

    Args:
        file_path (str | os.PathLike): Path to the .hdf5 recording file.
        fs (float, optional): Sampling frequency in Hz. Defaults to 2000.0.
        window_size (int, optional): Temporal window size in samples. Defaults to 1000.
        stride (int, optional): Stride between windows in samples. Defaults to 500.
        lowcut (float, optional): Band-pass lower edge in Hz. Defaults to 20.0.
        highcut (float, optional): Band-pass upper edge in Hz. Defaults to 450.0.
        notch_freq (float, optional): Notch centre frequency in Hz. Defaults to 50.0.
        notch_q (float, optional): Notch quality factor. Defaults to 30.0.

    Returns:
        np.ndarray: Processed windows of shape (N, window_size, n_ch). A
        recording shorter than one window yields N == 0 with the same rank.

    Raises:
        KeyError: Presumably raised by h5py if the expected
            ``emg2pose/timeseries`` layout is missing — verify against the data.
    """
    with h5py.File(file_path, "r") as f:
        emg = f["emg2pose"]["timeseries"]["emg"][:].astype(np.float32)

    # Too short for even one window: nothing to return, and filtfilt would
    # raise for signals shorter than its edge padding, aborting the batch.
    if emg.shape[0] < window_size:
        return np.empty((0, window_size) + emg.shape[1:], dtype=np.float32)

    emg_filt = bandpass_filter_emg(emg, lowcut, highcut, fs=fs)
    emg_filt = notch_filter(emg_filt, notch_freq, notch_q, fs=fs)

    # Per-channel standardization over the full recording.
    mu = emg_filt.mean(axis=0)
    sd = emg_filt.std(axis=0, ddof=1)
    sd[sd == 0] = 1.0  # avoid division by zero for flat channels
    emg_z = (emg_filt - mu) / sd

    return process_emg_features(emg_z, window_size, stride)
|
|
|
|
def _parse_args():
    """Define, parse and validate the command-line options for this script."""
    import argparse

    parser = argparse.ArgumentParser(description="Process EMG data from emg2pose.")
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Directory containing metadata.csv and the <filename>.hdf5 recordings.",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        required=True,
        help="Directory where one <split>.h5 file per split is written.",
    )
    parser.add_argument(
        "--seq_len",
        type=int,
        default=1000,
        help="Size of the window in samples for segmentation.",
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=500,
        help="Step size between windows in samples for segmentation.",
    )
    parser.add_argument(
        "--fs",
        type=float,
        default=2000.0,
        help="Sampling frequency of the recordings in Hz.",
    )
    parser.add_argument(
        "--subsample",
        type=float,
        default=1.0,
        help="Fraction of recordings to keep per split, in (0, 1]. 1.0 keeps all.",
    )
    parser.add_argument(
        "--n_jobs",
        type=int,
        default=-1,
        help="Number of parallel jobs to run. -1 means using all available cores.",
    )
    parser.add_argument(
        "--group_size",
        type=int,
        default=1000,
        help=(
            "Number of recordings processed per batch; each non-empty batch is "
            "stored as its own data_group_<k> in the output HDF5 file."
        ),
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed used when subsampling recordings."
    )
    args = parser.parse_args()

    if args.seq_len <= 0 or args.stride <= 0:
        parser.error("--seq_len and --stride must be positive integers.")
    if args.group_size <= 0:
        parser.error("--group_size must be a positive integer.")
    if args.fs <= 0:
        parser.error("--fs must be positive.")
    if not 0.0 < args.subsample <= 1.0:
        parser.error("--subsample must be in the range (0, 1].")
    return args


def _collect_split_files(data_dir: str, subsample: float, seed: int) -> Dict[str, List[Path]]:
    """Map each split listed in ``metadata.csv`` to its recording file paths.

    When ``subsample < 1``, a ``subsample`` fraction of rows is drawn within
    each split separately, using ``seed`` for reproducibility.
    """
    root = Path(data_dir).expanduser()
    df = pd.read_csv(root / "metadata.csv")
    if subsample < 1.0:
        df = df.groupby("split", group_keys=False).sample(
            frac=subsample, random_state=seed
        )
        df = df.reset_index(drop=True)
    return {
        split: [root / f"{name}.hdf5" for name in group["filename"]]
        for split, group in df.groupby("split")
    }


def _write_split(
    files: List[Path],
    out_file: str,
    desc: str,
    *,
    fs: float,
    window_size: int,
    stride: int,
    n_jobs: int,
    group_size: int,
) -> None:
    """Process ``files`` in batches and write each batch's windows to ``out_file``.

    Each batch of up to ``group_size`` recordings is processed in parallel by
    ``process_one_recording``. The windows are concatenated, transposed to
    (N, n_ch, window_size), cast to float32 and stored as dataset ``X`` in a
    new group ``data_group_<k>``. Any existing ``out_file`` is replaced.
    """
    if os.path.exists(out_file):
        os.remove(out_file)

    with h5py.File(out_file, "w") as h5f, Parallel(n_jobs=n_jobs) as parallel:
        with tqdm(total=len(files), desc=desc) as pbar:
            group_idx = 0
            for start in range(0, len(files), group_size):
                batch_files = files[start:start + group_size]
                results = parallel(
                    delayed(process_one_recording)(file_path, fs, window_size, stride)
                    for file_path in batch_files
                )
                # Recordings shorter than one window contribute nothing; dropping
                # them also prevents concatenating arrays of different rank.
                segments = [r for r in results if len(r) > 0]
                if segments:
                    x_chunk = np.concatenate(segments, axis=0).transpose(0, 2, 1)
                    grp = h5f.create_group(f"data_group_{group_idx}")
                    grp.create_dataset(
                        "X", data=np.ascontiguousarray(x_chunk, dtype=np.float32)
                    )
                    group_idx += 1
                    del x_chunk
                # Release this batch's arrays before loading the next batch.
                del results, segments
                gc.collect()
                pbar.update(len(batch_files))


def main():
    """Command-line entry point: window emg2pose recordings into per-split HDF5 files.

    Reads ``metadata.csv`` from ``--data_dir``, optionally subsamples the
    recordings of each split, and writes ``<split>.h5`` into ``--save_dir``.
    """
    args = _parse_args()

    # Expand "~" like the input paths so both directories resolve consistently.
    save_dir = os.path.expanduser(args.save_dir)
    os.makedirs(save_dir, exist_ok=True)

    fs = args.fs
    window_size, stride = args.seq_len, args.stride
    window_seconds = sequence_to_seconds(window_size, fs)
    print(f"Window size: {window_size} samples ({window_seconds:.2f} seconds)")

    splits = _collect_split_files(args.data_dir, args.subsample, args.seed)
    for split, files in splits.items():
        out_file = os.path.join(save_dir, f"{split}.h5")
        print(f"Processing {split} split ({len(files)} files)...")
        _write_split(
            files,
            out_file,
            f"Processing & Saving {split}",
            fs=fs,
            window_size=window_size,
            stride=stride,
            n_jobs=args.n_jobs,
            group_size=args.group_size,
        )


if __name__ == "__main__":
    main()