| import glob
|
| import json
|
| import os
|
| from typing import Tuple, List, Dict, Any
|
|
|
| import h5py
|
| import numpy as np
|
| import scipy.signal as signal
|
| from joblib import Parallel, delayed
|
| from scipy.signal import iirnotch
|
| from tqdm.auto import tqdm
|
|
|
def sequence_to_seconds(seq_len: int, fs: float) -> float:
    """Return the duration in seconds of a sequence of samples.

    Args:
        seq_len (int): The number of samples in the sequence.
        fs (float): The sampling frequency in Hz.

    Returns:
        float: The duration of the sequence in seconds.
    """
    duration = seq_len / fs
    return duration
|
|
|
|
|
# Target sampling frequency in Hz (tfs) and number of EMG channels (n_ch)
# used throughout this pipeline — presumably matching the EMG-EPN612
# armband recordings; verify against the dataset spec.
tfs, n_ch = 200.0, 8


# Integer class id for each EPN612 gesture name. Id 6 ("notProvided")
# marks unlabeled samples and is filtered out downstream (lbl != 6).
gesture_map = {
    "noGesture": 0,
    "waveIn": 1,
    "waveOut": 2,
    "pinch": 3,
    "open": 4,
    "fist": 5,
    "notProvided": 6,
}
|
|
|
|
|
def bandpass_filter_emg(
    emg: np.ndarray,
    low: float = 20.0,
    high: float = 90.0,
    fs: float = tfs,
    order: int = 4
) -> np.ndarray:
    """Band-limit an EMG signal with a zero-phase Butterworth filter.

    Args:
        emg (np.ndarray): The input signal array of shape (n_ch, T).
        low (float, optional): Lower bound of the passband in Hz. Defaults to 20.0.
        high (float, optional): Upper bound of the passband in Hz. Defaults to 90.0.
        fs (float, optional): The sampling frequency in Hz. Defaults to 200.0.
        order (int, optional): The order of the filter. Defaults to 4.

    Returns:
        np.ndarray: The filtered signal array.
    """
    nyquist = 0.5 * fs
    # Normalize the band edges to the Nyquist frequency as scipy expects.
    band = [low / nyquist, high / nyquist]
    b, a = signal.butter(order, band, btype="bandpass")
    # filtfilt applies the filter forward and backward (zero phase shift).
    filtered = signal.filtfilt(b, a, emg, axis=1)
    return filtered
|
|
|
|
|
def notch_filter_emg(
    emg: np.ndarray,
    notch: float = 50.0,
    Q: float = 30.0,
    fs: float = tfs
) -> np.ndarray:
    """Remove power-line interference with a zero-phase IIR notch filter.

    Args:
        emg (np.ndarray): The input signal array of shape (n_ch, T).
        notch (float, optional): The frequency to be removed in Hz. Defaults to 50.0.
        Q (float, optional): The quality factor. Defaults to 30.0.
        fs (float, optional): The sampling frequency in Hz. Defaults to 200.0.

    Returns:
        np.ndarray: The filtered signal array.
    """
    # Express the notch frequency as a fraction of the Nyquist frequency.
    normalized_freq = notch / (0.5 * fs)
    b, a = iirnotch(normalized_freq, Q)
    # Forward-backward filtering keeps the phase of the signal intact.
    return signal.filtfilt(b, a, emg, axis=1)
|
|
|
|
|
def zscore_per_channel(emg: np.ndarray) -> np.ndarray:
    """Standardize each channel of an EMG signal to zero mean, unit variance.

    Channels with zero variance are left centered but undivided (their
    standard deviation is treated as 1 to avoid division by zero).

    Args:
        emg (np.ndarray): The input EMG signal of shape (n_ch, T).

    Returns:
        np.ndarray: The normalized EMG signal.
    """
    mu = np.mean(emg, axis=1, keepdims=True)
    sigma = np.std(emg, axis=1, ddof=1, keepdims=True)
    # Guard against constant channels: avoid dividing by zero.
    sigma = np.where(sigma == 0, 1.0, sigma)
    return (emg - mu) / sigma
|
|
|
|
|
def adjust_length(x: np.ndarray, max_len: int) -> np.ndarray:
    """Force the signal to an exact temporal length by clipping or zero-padding.

    Args:
        x (np.ndarray): The input signal of shape (n_ch, T).
        max_len (int): The target length in samples.

    Returns:
        np.ndarray: The standardized length signal of shape (n_ch, max_len).
    """
    n_channels, length = x.shape
    if length < max_len:
        # Too short: copy into a zero-filled buffer of the target length.
        padded = np.zeros((n_channels, max_len), dtype=x.dtype)
        padded[:, :length] = x
        return padded
    # Long enough: truncate to the first max_len samples.
    return x[:, :max_len]
|
|
|
|
|
def extract_emg_signal(sample: Dict[str, Any], seq_len: int) -> Tuple[np.ndarray, int]:
    """Turn one raw JSON sample into a filtered, normalized EMG segment.

    Args:
        sample (Dict[str, Any]): A single sample dictionary from the EPN612 JSON.
        seq_len (int): Target temporal length.

    Returns:
        Tuple[np.ndarray, int]: A tuple containing:
            - The preprocessed EMG signal (n_ch, seq_len).
            - The gesture label ID.
    """
    # Stack the per-channel lists into (n_ch, T) and scale the raw values.
    channels = list(sample["emg"].values())
    emg = np.stack(channels, dtype=np.float32) / 128.0
    # Filtering pipeline: band-limit, remove mains hum, standardize, resize.
    emg = bandpass_filter_emg(emg, 20.0, 90.0)
    emg = notch_filter_emg(emg, 50.0, 30.0)
    emg = zscore_per_channel(emg)
    emg = adjust_length(emg, seq_len)
    gesture_name = sample.get("gestureName", "notProvided")
    # Unknown gesture names fall back to the "notProvided" id (6).
    label = gesture_map.get(gesture_name, 6)
    return emg, label
|
|
|
|
|
def process_user_training(
    path: str,
    seq_len: int
) -> Tuple[List[np.ndarray], List[int], List[np.ndarray], List[int]]:
    """Processes a user's training JSON file for the training and validation splits.

    The file's "trainingSamples" feed the train split and its
    "testingSamples" feed the validation split. Samples labeled
    "notProvided" (id 6) are dropped from both.

    Args:
        path (str): Path to the user JSON file.
        seq_len (int): Target temporal length for segmentation.

    Returns:
        Tuple[List[np.ndarray], List[int], List[np.ndarray], List[int]]:
            (train_X, train_y, val_X, val_y) lists.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    def _collect(samples: Dict[str, Any]) -> Tuple[List[np.ndarray], List[int]]:
        # Extract all labeled segments from one group of samples,
        # skipping "notProvided" (id 6).
        X: List[np.ndarray] = []
        y: List[int] = []
        for sample in samples.values():
            emg, lbl = extract_emg_signal(sample, seq_len)
            if lbl != 6:
                X.append(emg)
                y.append(lbl)
        return X, y

    # NOTE: the original code contained a dead `if lbl != 10: pass` branch
    # in the validation loop; it had no effect and has been removed.
    train_X, train_y = _collect(data.get("trainingSamples", {}))
    val_X, val_y = _collect(data.get("testingSamples", {}))
    return train_X, train_y, val_X, val_y
|
|
|
|
|
def process_user_testing(
    path: str,
    seq_len: int
) -> Tuple[List[np.ndarray], List[int], List[np.ndarray], List[int]]:
    """Processes a user's testing JSON file for the fine-tuning and test splits.

    Within each gesture, the first 10 samples go to the fine-tuning split
    and the remainder to the test split; "notProvided" samples are dropped.

    Args:
        path (str): Path to the user JSON file.
        seq_len (int): Target temporal length for segmentation.

    Returns:
        Tuple[List[np.ndarray], List[int], List[np.ndarray], List[int]]:
            (tune_X, tune_y, test_X, test_y) lists.
    """
    tune_X: List[np.ndarray] = []
    tune_y: List[int] = []
    test_X: List[np.ndarray] = []
    test_y: List[int] = []
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # Group samples by gesture so the per-gesture 10-sample cutoff applies.
    grouped: Dict[str, List[Dict[str, Any]]] = {name: [] for name in gesture_map}
    for sample in data.get("trainingSamples", {}).values():
        gesture_name = sample.get("gestureName", "notProvided")
        grouped.setdefault(gesture_name, []).append(sample)
    for gesture_samples in grouped.values():
        for idx, sample in enumerate(gesture_samples):
            emg, lbl = extract_emg_signal(sample, seq_len)
            if lbl == 6:
                continue
            # First 10 samples per gesture → fine-tuning, the rest → test.
            if idx < 10:
                tune_X.append(emg)
                tune_y.append(lbl)
            else:
                test_X.append(emg)
                test_y.append(lbl)
    return tune_X, tune_y, test_X, test_y
|
|
|
|
|
def save_h5(path: str, data: List[np.ndarray], labels: List[int]) -> None:
    """Persist EMG segments and their labels as an HDF5 file.

    Args:
        path (str): Output file path.
        data (List[np.ndarray]): List of signal segments.
        labels (List[int]): List of categorical labels.
    """
    X = np.asarray(data, np.float32)
    y = np.asarray(labels, np.int64)
    with h5py.File(path, "w") as out:
        out.create_dataset("data", data=X)
        out.create_dataset("label", data=y)
|
|
|
|
|
def main():
    """CLI entry point: optionally download EMG-EPN612, preprocess all users,
    and export train/val/test splits as HDF5 files."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--download_data", action="store_true")
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--source_training", required=True)
    parser.add_argument("--source_testing", required=True)
    parser.add_argument("--dest_dir", required=True)
    parser.add_argument(
        "--seq_len",
        type=int,
        required=True,  # was optional with no default → TypeError downstream
        help="Size of the window in samples for segmentation.",
    )
    parser.add_argument("--n_jobs", type=int, default=-1)
    args = parser.parse_args()
    data_dir = args.data_dir
    os.makedirs(args.dest_dir, exist_ok=True)

    if args.download_data:
        # NOTE(review): these shell commands interpolate data_dir directly;
        # a path containing spaces or shell metacharacters will break them.
        # Consider subprocess.run([...]) with list arguments.
        url = "https://zenodo.org/records/4421500/files/EMG-EPN612%20Dataset.zip?download=1"
        os.system(f"wget -O {data_dir}/EMG-EPN612_Dataset.zip {url}")
        os.system(f"unzip -o {data_dir}/EMG-EPN612_Dataset.zip -d {data_dir}")
        os.system(rf"mv {data_dir}/EMG-EPN612\ Dataset/* {data_dir}/")
        # Fix: the extracted directory is "EMG-EPN612 Dataset" (with a space);
        # the old command removed a nonexistent underscore-named directory.
        os.system(rf"rmdir {data_dir}/EMG-EPN612\ Dataset")
        os.system(f"rm {data_dir}/EMG-EPN612_Dataset.zip")
        # Fix: report the directory that now holds the data, not the
        # zip file that was just deleted.
        print(f"Downloaded and unzipped dataset into {data_dir}")

    seq_len = args.seq_len
    window_seconds = sequence_to_seconds(seq_len, tfs)
    print(f"Window size: {seq_len} samples ({window_seconds:.2f} seconds)")

    train_X, train_y, val_X, val_y, test_X, test_y = [], [], [], [], [], []

    # Training-split users: trainingSamples → train, testingSamples → val.
    train_paths = glob.glob(os.path.join(args.source_training, "user*", "user*.json"))
    results = Parallel(n_jobs=args.n_jobs)(
        delayed(process_user_training)(p, seq_len)
        for p in tqdm(train_paths, desc="Training files")
    )
    for tX, ty, vX, vy in results:
        train_X.extend(tX)
        train_y.extend(ty)
        val_X.extend(vX)
        val_y.extend(vy)

    # Testing-split users: first 10 samples/gesture augment the train split
    # (fine-tuning data), the remainder form the test split.
    test_paths = glob.glob(os.path.join(args.source_testing, "user*", "user*.json"))
    test_results = Parallel(n_jobs=args.n_jobs)(
        delayed(process_user_testing)(p, seq_len)
        for p in tqdm(test_paths, desc="Testing files")
    )
    for tX, ty, teX, tey in test_results:
        train_X.extend(tX)
        train_y.extend(ty)
        test_X.extend(teX)
        test_y.extend(tey)

    save_h5(os.path.join(args.dest_dir, "train.h5"), train_X, train_y)
    save_h5(os.path.join(args.dest_dir, "val.h5"), val_X, val_y)
    save_h5(os.path.join(args.dest_dir, "test.h5"), test_X, test_y)

    # Per-split class distribution summary.
    for split, _X, y in [
        ("Train", train_X, train_y),
        ("Val", val_X, val_y),
        ("Test", test_X, test_y),
    ]:
        uniq, cnt = np.unique(np.array(y), return_counts=True)
        counts = {u.item(): c.item() for u, c in zip(uniq, cnt)}
        print(f"{split} → total={len(y)}, classes={counts}")
|
|
|
|
|
| if __name__ == "__main__":
|
| main()
|
|
|