import os import librosa import h5py import numpy as np import scipy.io as sio import scipy.signal as signal from pathlib import Path from typing import Tuple, List import re import argparse from huggingface_hub import snapshot_download from joblib import Parallel, delayed from tqdm import tqdm def download_emg_only(save_dir: str): repo_id = "MML-Group/AVE-Speech" allow_patterns = [ "Train/EMG/**", "Val/EMG/**", "Test/EMG/**", "phonetic_transcription.xlsx", ] snapshot_download( repo_id=repo_id, repo_type="dataset", local_dir=save_dir, allow_patterns=allow_patterns, ) def unzip_file(zip_path: str, extract_to: str) -> None: import zipfile with zipfile.ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall(extract_to) def unzip_all_subjects(base_dir: str): base_path = Path(base_dir) pattern = re.compile(r"subject_(\d+)\.zip") for zip_file in base_path.rglob("*.zip"): match = pattern.search(zip_file.name) if not match: continue subject_id = match.group(1) extract_dir = zip_file.parent / f"subject_{subject_id}" extract_dir.mkdir(exist_ok=True) print(f"Unzipping {zip_file} -> {extract_dir}") unzip_file(str(zip_file), str(extract_dir)) zip_file.unlink() def filter(raw_data): fs=1000 b1, a1 = signal.iirnotch(50, 30, fs) b2, a2 = signal.iirnotch(150, 30, fs) b3, a3 = signal.iirnotch(250, 30, fs) b4, a4 = signal.iirnotch(350, 30, fs) b5, a5 = signal.butter(4, [10/(fs/2), 400/(fs/2)], 'bandpass') x = signal.filtfilt(b1, a1, raw_data, axis=1) x = signal.filtfilt(b2, a2, x, axis=1) x = signal.filtfilt(b3, a3, x, axis=1) x = signal.filtfilt(b4, a4, x, axis=1) x = signal.filtfilt(b5, a5, x, axis=1) return x def zscore(x: np.ndarray) -> np.ndarray: mu = x.mean(axis=1, keepdims=True) std = x.std(axis=1, keepdims=True) + 1e-8 return (x - mu) / std def EMG_MFSC(x): x = x[:,250:,:] n_mels = 36 sr = 1000 channel_list = [] for j in range(x.shape[-1]): mfsc_x = np.zeros((x.shape[0], 36, n_mels)) for i in range(x.shape[0]): # norm_x = x[i, :, j]/np.max(abs(x[i, :, j])) norm_x = np.asfortranarray(x[i, :, j]) tmp = librosa.feature.melspectrogram(y=norm_x, sr=sr, n_mels=n_mels, n_fft=200, hop_length=50) tmp = librosa.power_to_db(tmp).T mfsc_x[i, :, :] = tmp mfsc_x = np.expand_dims(mfsc_x, axis=-1) channel_list.append(mfsc_x) data_x = np.concatenate(channel_list, axis=-1) mu = np.mean(data_x) std = np.std(data_x) data_x = (data_x - mu) / std data_x = data_x.transpose(0,3,1,2) # Shape: (N, C, F, T) return data_x def process_subject(subject_path: Path, use_mfsc: bool) -> Tuple[List[np.ndarray], List[int]]: X_list, y_list = [], [] for mat_file in subject_path.rglob("*.mat"): emg = sio.loadmat(mat_file) # [2000, 6] emg = np.expand_dims(emg["data"], axis=0) # Shape: (1, 2000, 6) emg = filter(emg) if use_mfsc: emg = EMG_MFSC(emg) else: emg = zscore(emg) emg = emg.squeeze(0) # Shape: (2000, 6) emg = emg.transpose(1, 0) # Shape: (6, 2000) [C, T] label = int(mat_file.stem) X_list.append(emg) y_list.append(label) return X_list, y_list def process_dataset( data_dir: str, save_dir: str, use_mfsc: bool, n_jobs: int, ): splits = ["Train", "Val", "Test"] os.makedirs(save_dir, exist_ok=True) for split in splits: split_path = Path(data_dir) / split / "EMG" if not split_path.exists(): continue print(f"\nProcessing {split}...") subjects = [p for p in split_path.iterdir() if p.is_dir()] # Parallel process subjects results = Parallel(n_jobs=n_jobs, backend="loky")( delayed(process_subject)(subj, use_mfsc) for subj in tqdm(subjects) ) X_all, y_all = [], [] for X_list, y_list in results: if X_list is None: continue X_all.extend(X_list) y_all.extend(y_list) if len(X_all) == 0: continue X = np.array(X_all, dtype=np.float32) y = np.array(y_all, dtype=np.int64) # Save to HDF5 with h5py.File(os.path.join(save_dir, f"{split.lower()}.h5"), "w") as f: f.create_dataset("data", data=X) f.create_dataset("label", data=y) print(f"{split}: Processed {len(X)} samples.") print(f"Saved shapes -> X: {X.shape}, y: {y.shape}") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--data_dir", type=str, required=True) parser.add_argument("--save_dir", type=str, required=True) parser.add_argument("--download_data", action="store_true") parser.add_argument("--use_mfsc", action="store_true") parser.add_argument("--n_jobs", type=int, default=-1) args = parser.parse_args() os.makedirs(args.data_dir, exist_ok=True) os.makedirs(args.save_dir, exist_ok=True) if args.download_data: print("Downloading dataset...") download_emg_only(args.data_dir) print("Unzipping dataset...") unzip_all_subjects(args.data_dir) print("Processing dataset...") process_dataset( data_dir=args.data_dir, save_dir=args.save_dir, use_mfsc=args.use_mfsc, n_jobs=args.n_jobs )