TinyMyo / scripts /epn.py
MatteoFasulo's picture
refactor: update dataset processing commands and improve script structure
f7b4d24
import glob
import json
import os
from typing import Tuple, List, Dict, Any
import h5py
import numpy as np
import scipy.signal as signal
from joblib import Parallel, delayed
from scipy.signal import iirnotch
from tqdm.auto import tqdm
def sequence_to_seconds(seq_len: int, fs: float) -> float:
"""Converts a sequence length in samples to time in seconds.
Args:
seq_len (int): The number of samples in the sequence.
fs (float): The sampling frequency in Hz.
Returns:
float: The duration of the sequence in seconds.
"""
return seq_len / fs
# Sampling frequency and EMG channels
tfs, n_ch = 200.0, 8
# Gesture label mapping
gesture_map = {
"noGesture": 0,
"waveIn": 1,
"waveOut": 2,
"pinch": 3,
"open": 4,
"fist": 5,
"notProvided": 6,
}
def bandpass_filter_emg(
emg: np.ndarray,
low: float = 20.0,
high: float = 90.0,
fs: float = tfs,
order: int = 4
) -> np.ndarray:
"""Applies a Butterworth bandpass filter to the EMG signal.
Args:
emg (np.ndarray): The input signal array of shape (n_ch, T).
low (float, optional): Lower bound of the passband in Hz. Defaults to 20.0.
high (float, optional): Upper bound of the passband in Hz. Defaults to 90.0.
fs (float, optional): The sampling frequency in Hz. Defaults to 200.0.
order (int, optional): The order of the filter. Defaults to 4.
Returns:
np.ndarray: The filtered signal array.
"""
nyq = 0.5 * fs
b, a = signal.butter(order, [low / nyq, high / nyq], btype="bandpass")
return signal.filtfilt(b, a, emg, axis=1)
def notch_filter_emg(
emg: np.ndarray,
notch: float = 50.0,
Q: float = 30.0,
fs: float = tfs
) -> np.ndarray:
"""Applies a notch filter to remove power line interference.
Args:
emg (np.ndarray): The input signal array of shape (n_ch, T).
notch (float, optional): The frequency to be removed in Hz. Defaults to 50.0.
Q (float, optional): The quality factor. Defaults to 30.0.
fs (float, optional): The sampling frequency in Hz. Defaults to 200.0.
Returns:
np.ndarray: The filtered signal array.
"""
w0 = notch / (0.5 * fs)
b, a = iirnotch(w0, Q)
return signal.filtfilt(b, a, emg, axis=1)
def zscore_per_channel(emg: np.ndarray) -> np.ndarray:
"""Normalizes the EMG signal using Z-score (per channel).
Args:
emg (np.ndarray): The input EMG signal of shape (n_ch, T).
Returns:
np.ndarray: The normalized EMG signal.
"""
mean = emg.mean(axis=1, keepdims=True)
std = emg.std(axis=1, ddof=1, keepdims=True)
std[std == 0] = 1.0
return (emg - mean) / std
def adjust_length(x: np.ndarray, max_len: int) -> np.ndarray:
"""Standardizes the temporal length of the signal by clipping or zero-padding.
Args:
x (np.ndarray): The input signal of shape (n_ch, T).
max_len (int): The target length in samples.
Returns:
np.ndarray: The standardized length signal of shape (n_ch, max_len).
"""
n_ch, seq_len = x.shape
if seq_len >= max_len:
return x[:, :max_len]
pad = np.zeros((n_ch, max_len - seq_len), dtype=x.dtype)
return np.concatenate([x, pad], axis=1)
def extract_emg_signal(sample: Dict[str, Any], seq_len: int) -> Tuple[np.ndarray, int]:
"""Extracts, filters, and normalizes EMG data from a JSON sample.
Args:
sample (Dict[str, Any]): A single sample dictionary from the EPN612 JSON.
seq_len (int): Target temporal length.
Returns:
Tuple[np.ndarray, int]: A tuple containing:
- The preprocessed EMG signal (n_ch, seq_len).
- The gesture label ID.
"""
emg = np.stack([v for v in sample["emg"].values()], dtype=np.float32) / 128.0
emg = bandpass_filter_emg(emg, 20.0, 90.0)
emg = notch_filter_emg(emg, 50.0, 30.0)
emg = zscore_per_channel(emg)
emg = adjust_length(emg, seq_len)
label = gesture_map.get(sample.get("gestureName", "notProvided"), 6)
return emg, label
def process_user_training(
path: str,
seq_len: int
) -> Tuple[List[np.ndarray], List[int], List[np.ndarray], List[int]]:
"""Processes a user's training JSON file for the training and validation splits.
Args:
path (str): Path to the user JSON file.
seq_len (int): Target temporal length for segmentation.
Returns:
Tuple[List[np.ndarray], List[int], List[np.ndarray], List[int]]:
(train_X, train_y, val_X, val_y) lists.
"""
train_X, train_y, val_X, val_y = [], [], [], []
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
for sample in data.get("trainingSamples", {}).values():
emg, lbl = extract_emg_signal(sample, seq_len)
if lbl != 6:
train_X.append(emg)
train_y.append(lbl)
for sample in data.get("testingSamples", {}).values():
emg, lbl = extract_emg_signal(sample, seq_len)
if lbl != 10: # Assuming 10 was the intention or checking if not invalid
pass
# Note: checking lbl != 6 as in original
if lbl != 6:
val_X.append(emg)
val_y.append(lbl)
return train_X, train_y, val_X, val_y
def process_user_testing(
path: str,
seq_len: int
) -> Tuple[List[np.ndarray], List[int], List[np.ndarray], List[int]]:
"""Processes a user's testing JSON file for the fine-tuning and test splits.
Args:
path (str): Path to the user JSON file.
seq_len (int): Target temporal length for segmentation.
Returns:
Tuple[List[np.ndarray], List[int], List[np.ndarray], List[int]]:
(tune_X, tune_y, test_X, test_y) lists.
"""
train_X, train_y, test_X, test_y = [], [], [], []
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
buckets = {g: [] for g in gesture_map}
for sample in data.get("trainingSamples", {}).values():
buckets.setdefault(sample.get("gestureName", "notProvided"), []).append(sample)
for samples in buckets.values():
for i, sample in enumerate(samples):
emg, lbl = extract_emg_signal(sample, seq_len)
if lbl == 6:
continue
if i < 10:
train_X.append(emg)
train_y.append(lbl)
else:
test_X.append(emg)
test_y.append(lbl)
return train_X, train_y, test_X, test_y
def save_h5(path: str, data: List[np.ndarray], labels: List[int]) -> None:
"""Saves the processed EMG data and labels to an HDF5 file.
Args:
path (str): Output file path.
data (List[np.ndarray]): List of signal segments.
labels (List[int]): List of categorical labels.
"""
with h5py.File(path, "w") as f:
f.create_dataset("data", data=np.asarray(data, np.float32))
f.create_dataset("label", data=np.asarray(labels, np.int64))
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--download_data", action="store_true")
parser.add_argument("--data_dir", type=str, required=True)
parser.add_argument("--source_training", required=True)
parser.add_argument("--source_testing", required=True)
parser.add_argument("--dest_dir", required=True)
parser.add_argument(
"--seq_len", type=int, help="Size of the window in samples for segmentation."
)
parser.add_argument("--n_jobs", type=int, default=-1)
args = parser.parse_args()
data_dir = args.data_dir
os.makedirs(args.dest_dir, exist_ok=True)
# download data if requested
if args.download_data:
# https://zenodo.org/records/4421500
url = "https://zenodo.org/records/4421500/files/EMG-EPN612%20Dataset.zip?download=1"
os.system(f"wget -O {data_dir}/EMG-EPN612_Dataset.zip {url}")
os.system(f"unzip -o {data_dir}/EMG-EPN612_Dataset.zip -d {data_dir}")
# move the contents one level up
os.system(rf"mv {data_dir}/EMG-EPN612\ Dataset/* {data_dir}/")
os.system(f"rmdir {data_dir}/EMG-EPN612_Dataset")
# clean up zip file
os.system(f"rm {data_dir}/EMG-EPN612_Dataset.zip")
print(f"Downloaded and unzipped dataset\n{data_dir}/EMG-EPN612_Dataset.zip")
seq_len = args.seq_len
window_seconds = sequence_to_seconds(seq_len, tfs)
print(f"Window size: {seq_len} samples ({window_seconds:.2f} seconds)")
train_X, train_y, val_X, val_y, test_X, test_y = [], [], [], [], [], []
paths = glob.glob(os.path.join(args.source_training, "user*", "user*.json"))
# Parallel process training JSONs
results = Parallel(n_jobs=args.n_jobs)(
delayed(process_user_training)(p, seq_len)
for p in tqdm(paths, desc="Training files")
)
for tX, ty, vX, vy in results:
train_X.extend(tX)
train_y.extend(ty)
val_X.extend(vX)
val_y.extend(vy)
# Parallel process testing JSONs
test_results = Parallel(n_jobs=args.n_jobs)(
delayed(process_user_testing)(p, seq_len)
for p in tqdm(
glob.glob(os.path.join(args.source_testing, "user*", "user*.json")),
desc="Testing files",
)
)
for tX, ty, teX, tey in test_results:
train_X.extend(tX)
train_y.extend(ty)
test_X.extend(teX)
test_y.extend(tey)
# Save datasets
save_h5(os.path.join(args.dest_dir, "train.h5"), train_X, train_y)
save_h5(os.path.join(args.dest_dir, "val.h5"), val_X, val_y)
save_h5(os.path.join(args.dest_dir, "test.h5"), test_X, test_y)
# Print distributions
for split, X, y in [
("Train", train_X, train_y),
("Val", val_X, val_y),
("Test", test_X, test_y),
]:
arr = np.array(y)
uniq, cnt = np.unique(arr, return_counts=True)
uniq = [i.item() for i in uniq]
cnt = [i.item() for i in cnt]
print(f"{split} → total={len(y)}, classes={{}}".format(dict(zip(uniq, cnt))))
if __name__ == "__main__":
main()