MatteoFasulo committed on
Commit
45d17fb
·
1 Parent(s): 6d283f0

refactor: EMG processing scripts and documentation

Files changed (9)
  1. README.md +47 -280
  2. scripts/README.md +26 -125
  3. scripts/db5.py +122 -13
  4. scripts/db6.py +108 -60
  5. scripts/db7.py +108 -58
  6. scripts/db8.py +87 -20
  7. scripts/emg2pose.py +118 -38
  8. scripts/epn.py +118 -16
  9. scripts/uci.py +132 -32
README.md CHANGED
@@ -103,310 +103,77 @@ tags:
103
 
104
  <div align="center">
105
  <img src="https://raw.githubusercontent.com/MatteoFasulo/BioFoundation/refs/heads/TinyMyo/docs/model/logo/TinyMyo_logo.png" alt="TinyMyo Logo" width="400" />
106
- <h1>TinyMyo: a Tiny Foundation Model for Flexible EMG Signal Processing at the Edge</h1>
107
- </div>
108
- <p align="center">
109
- <a href="https://github.com/pulp-bio/BioFoundation">
110
- <img src ="https://img.shields.io/github/stars/pulp-bio/BioFoundation?color=ccf" alt="Github">
111
- </a>
112
- <a href="https://creativecommons.org/licenses/by-nd/4.0/">
113
- <img src="https://img.shields.io/badge/License-CC_BY--ND_4.0-lightgrey.svg" alt="License">
114
- </a>
115
- <a href="https://arxiv.org/abs/2512.15729">
116
- <img src="https://img.shields.io/badge/arXiv-2512.15729-b31b1b.svg" alt="Paper">
117
- </a>
118
  </p>
119
 
120
- **TinyMyo** is a **3.6M-parameter** Transformer-based **foundation model for surface EMG (sEMG)**.
121
- It is pretrained on >480 GB of EMG data and optimized for **ultra-low-power, real-time deployment**, including **microcontrollers (GAP9)**, where it achieves an inference time of **0.785 s**, an energy of **44.91 mJ**, and a power envelope of **57.18 mW**.
122
-
123
- TinyMyo is built for **broad generalization** across datasets, sensor configurations, movement tasks, subjects, and domains (gesture, kinematics, speech).
124
-
125
- ---
126
-
127
- # 🔒 License & Usage (Model Weights)
128
-
129
- The released TinyMyo weights are licensed under **CC BY-ND 4.0**.
130
- This summary is not legal advice; please read the full license.
131
-
132
- ### ✅ You may
133
-
134
- * **Use** and **redistribute** the **unmodified** TinyMyo weights (including commercially) **with attribution**.
135
- * **Fine-tune/modify internally** for research or production without redistributing modified weights.
136
- * **Publish code, configs, evaluations, and papers** using TinyMyo.
137
-
138
- ### 🚫 You may not
139
-
140
- * **Share or host modified weights** in any form (including LoRA/adapter deltas, pruned/quantized models).
141
- * **Claim endorsement** from the TinyMyo authors without permission.
142
- * **Use the TinyMyo name** for derivative models.
143
-
144
- ### 🤝 Contributing Improvements
145
-
146
- To upstream improvements, submit a **PR** to the
147
- **[BioFoundation repository](https://github.com/pulp-bio/BioFoundation)** with:
148
-
149
- 1. Full reproducibility artifacts (configs, logs, seeds, environment).
150
- 2. Evaluation on standard protocols (e.g., DB5, EPN-612, UCI EMG, DB8, Silent Speech).
151
- 3. Comparison to TinyMyo's reported metrics.
152
-
153
- Approved PRs will be retrained and released as **official TinyMyo** checkpoints under CC BY-ND.
154
-
155
- ---
156
-
157
- # 🔎 1. Default Input & Preprocessing
158
-
159
- Unless specified otherwise, TinyMyo expects:
160
-
161
- * **Channels:** 16
162
- * **Sampling rate:** 2000 Hz
163
- * **Segment length:** 1000 samples (0.5 s)
164
- * **Windowing:** 50% overlap (pretraining)
165
- * **Preprocessing:**
166
-
167
- * 4th-order **20–450 Hz bandpass**
168
- * **50 Hz notch filter**
169
- * **Min–max normalization** (pretraining)
170
- * **Z-score normalization** (downstream)
171
-
172
- Datasets with <16 channels are **zero-padded (pretraining only)**.
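The default pipeline above can be sketched with SciPy. This is a minimal illustration, not the released code: the function name and the epsilon guard in the normalization are assumptions.

```python
import numpy as np
from scipy.signal import butter, filtfilt, iirnotch

FS = 2000.0  # pretraining sampling rate (Hz)

def preprocess_window(emg: np.ndarray) -> np.ndarray:
    """emg: (T, C) raw window -> filtered, min-max normalized window."""
    # 4th-order 20-450 Hz Butterworth bandpass (zero-phase)
    b, a = butter(4, [20.0 / (FS / 2), 450.0 / (FS / 2)], btype="band")
    out = filtfilt(b, a, emg, axis=0)
    # 50 Hz notch for power-line interference
    bn, an = iirnotch(50.0, Q=30.0, fs=FS)
    out = filtfilt(bn, an, out, axis=0)
    # per-channel min-max normalization to [0, 1] (pretraining variant)
    mn, mx = out.min(axis=0), out.max(axis=0)
    return (out - mn) / (mx - mn + 1e-8)
```

For downstream tasks the min-max step would be swapped for z-score normalization, as stated above.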
173
-
174
- ---
175
-
176
- # 🔬 2. Pretraining Overview
177
-
178
- TinyMyo is pretrained via masked reconstruction on **three large-scale EMG datasets**:
179
-
180
- | Dataset | Subjects | fs | Channels | Size |
181
- | ----------- | -------- | ------- | -------- | ------- |
182
- | Ninapro DB6 | 10 | 2000 Hz | 14 | 20.3 GB |
183
- | Ninapro DB7 | 22 | 2000 Hz | 12 | 30.9 GB |
184
- | EMG2Pose | 192 | 2000 Hz | 16 | 431 GB |
185
-
186
- ## Tokenization: Channel-Independent Patches
187
-
188
- Unlike EEG FMs that mix channels early, TinyMyo uses **per-channel patching**:
189
-
190
- * Patch length: **20 samples**
191
- * Patch stride: **20 samples**
192
- * Tokens/channel: **50**
193
- * Total seq length: **800 tokens** (16 × 50)
194
- * Positional encoding: **RoPE**
195
-
196
- This preserves electrode-specific structure while allowing attention to learn cross-channel relationships.
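The token arithmetic above (16 channels × 50 patches of 20 samples = 800 tokens) can be checked with a small NumPy sketch; the array names are illustrative, and the linear projection to the 192-dim embedding is omitted.

```python
import numpy as np

# One 0.5 s window at 2 kHz: 16 channels x 1000 samples.
window = np.random.randn(16, 1000)

patch_len = 20  # samples per patch; stride 20 -> non-overlapping
C, T = window.shape

# Per-channel patching: each channel is split independently,
# so electrode-specific structure is preserved per token.
patches = window.reshape(C, T // patch_len, patch_len)      # (16, 50, 20)
tokens = patches.reshape(C * (T // patch_len), patch_len)   # (800, 20)
```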
197
-
198
- ## Transformer Encoder
199
-
200
- * **8 layers**, **3 heads**
201
- * Embedding dim: **192**
202
- * Pre-LayerNorm
203
- * Dropout & drop-path: **0.1**
204
-
205
- ## Lightweight Decoder
206
-
207
- A **single linear layer** (~3.9k params) reconstructs masked patches.
208
- Following SimMIM, this forces the encoder to learn robust latent structure.
209
-
210
- ## Masking Objective
211
-
212
- * **50% random masking** with a learnable `[MASK]` token
213
- * Loss: **Smooth L1** with small penalty on visible patches
214
- $$
215
- \mathcal{L} = \mathcal{L}_{\text{masked}} + 0.1\,\mathcal{L}_{\text{visible}}
216
- $$
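A minimal NumPy sketch of the stated objective: Smooth L1 on masked patches plus a 0.1-weighted term on visible ones. The helper names and the `beta` transition point are assumptions, not the training code.

```python
import numpy as np

def smooth_l1(pred, target, beta=1.0):
    # Quadratic below beta, linear above (Huber-style Smooth L1)
    d = np.abs(pred - target)
    return np.where(d < beta, 0.5 * d**2 / beta, d - 0.5 * beta).mean()

def reconstruction_loss(pred, target, mask, lam=0.1):
    """mask: boolean array (same shape), True where the patch was masked."""
    l_masked = smooth_l1(pred[mask], target[mask])
    l_visible = smooth_l1(pred[~mask], target[~mask])
    return l_masked + lam * l_visible
```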
217
-
218
- ## Training Setup
219
-
220
- * Optimizer: **AdamW** (β=(0.9, 0.98), wd=0.01)
221
- * LR: **1e-4** with cosine decay
222
- * Batch size: **512** (with grad accumulation)
223
- * Epochs: **50**, warm-up: 10
224
- * Hardware: **4× NVIDIA GH200 GPUs**
225
-
226
- ---
227
-
228
- # 🧠 3. Architecture Summary
229
-
230
- ### Model Variant
231
-
232
- | Variant | Params | (Layers, Dim) |
233
- | ------- | -------- | ------------- |
234
- | TinyMyo | **3.6M** | (8, 192) |
235
-
236
- ---
237
-
238
- # 🎯 4. Downstream Tasks
239
-
240
- TinyMyo generalizes across **gesture classification**, **kinematic regression**, and **speech EMG**, with state-of-the-art or competitive results.
241
-
242
- ---
243
-
244
- ## 4.1 Hand Gesture Classification
245
-
246
- Evaluated on:
247
-
248
- * **Ninapro DB5** (52 classes, 10 subjects, 200 Hz)
249
- * **EPN-612** (5 classes, 612 subjects, 200 Hz)
250
- * **UCI EMG** (6 classes, 36 subjects, 200 Hz)
251
-
252
- ### Preprocessing
253
-
254
- * EMG filtering: **20–90 Hz bandpass + 50 Hz notch**
255
- * Window sizes:
256
-
257
- * **200 samples** (1 sec, best for DB5)
258
- * **1000 samples** (5 sec, best for EPN, UCI)
259
-
260
- ### Linear Classification Head
261
-
262
- * Input: **C Γ— 192**
263
- * Params: **<40k**
264
-
265
- ### Performance (Fine-tuned)
266
-
267
- | Dataset | Metric | Result |
268
- | ------------------------ | ------ | ----------------- |
269
- | **Ninapro DB5** (1 sec)  | Acc    | **89.41 ± 0.16%** |
270
- | **EPN-612** (5 sec)      | Acc    | **96.74 ± 0.09%** |
271
- | **UCI EMG** (5 sec)      | Acc    | **97.56 ± 0.32%** |
272
-
273
- TinyMyo achieves **new state-of-the-art** on DB5, EPN-612, and UCI.
274
-
275
- ---
276
-
277
- ## 4.2 Hand Kinematic Regression (Ninapro DB8)
278
-
279
- * Predict **5 joint angles**
280
- * Windows: **100 ms** or **500 ms**
281
- * Normalization: z-score only
282
-
283
- ### Regression Head (~788k params)
284
-
285
- * Depthwise + pointwise convs
286
- * Upsampling
287
- * Global average pooling
288
- * Linear projection to 5 outputs
289
-
290
- ### Performance
291
-
292
- * **MAE = 8.77 ± 0.12°** (500 ms)
293
-
294
- Note: Prior works reporting ~6.9° MAE are **subject-specific**; TinyMyo trains a **single cross-subject model**, a significantly harder setting.
295
-
296
- ---
297
-
298
- ## 4.3 Speech Production & Recognition (Silent Speech)
299
-
300
- Dataset: **Gaddy Silent Speech**
301
- (8 channels, 1000 Hz, face/neck EMG)
302
-
303
- ### Speech Production (EMG → MFCC → HiFi-GAN → Audio)
304
-
305
- Pipeline:
306
-
307
- 1. Residual downsampling
308
- 2. TinyMyo encoder
309
- 3. Linear projection → **26-dim MFCC**
310
- 4. HiFi-GAN vocoder
311
-
312
- **WER:** **33.54 ± 1.12%**
313
- ≈ state-of-the-art with **>90% fewer params** in the transduction model.
314
-
315
- ### Speech Recognition (EMG → Text)
316
-
317
- * TinyMyo encoder
318
- * Linear projection → **37 characters**
319
- * **CTC** loss
320
- * 4-gram LM + beam search
321
-
322
- **WER:** **33.95 ± 0.97%**
323
-
324
- TinyMyo is EMG-only, unlike multimodal systems such as MONA-LISA.
325
 
326
  ---
327
 
328
- # ⚡ 5. Edge Deployment (GAP9 MCU)
329
-
330
- TinyMyo runs efficiently on **GAP9 (RISC-V)** via:
331
-
332
- * **INT8 quantization**, including attention
333
- * Multi-level streaming (L3 to L2 to L1)
334
- * Integer LayerNorm, GELU, softmax
335
- * Static memory arena via liveness analysis
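As an illustration of the INT8 step, here is a minimal symmetric per-tensor quantizer in NumPy; the actual GAP9 toolchain, calibration, and integer kernels are not reproduced here.

```python
import numpy as np

def quantize_int8(x: np.ndarray):
    """Symmetric per-tensor INT8 quantization: x ~= scale * q."""
    scale = np.abs(x).max() / 127.0
    q = np.clip(np.round(x / scale), -128, 127).astype(np.int8)
    return q, scale

def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    return q.astype(np.float32) * scale

# Round-trip error is bounded by half a quantization step (scale / 2).
x = np.random.randn(192).astype(np.float32)
q, s = quantize_int8(x)
err = np.abs(dequantize(q, s) - x).max()
```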
336
 
337
- ### Runtime (EPN-612 dataset)
338
 
339
- * **Inference time**: **0.785 s**
340
- * **Energy**: **44.91 mJ**
341
- * **Average power**: **57.18 mW**
342
-
343
- This is the **first EMG foundation model demonstrated on a microcontroller**.
344
-
345
- ---
346
 
347
- # 📊 6. Results Summary
348
-
349
- ### Pretraining
350
-
351
- * Smooth L1 reconstruction with high fidelity
352
- * Total compute β‰ˆ **4.0 GFLOPs**
353
-
354
- ### Downstream Highlights
355
-
356
- * **DB5:** 89.41%
357
- * **EPN-612:** 96.74%
358
- * **UCI EMG:** 97.56%
359
- * **Neuromotor:** 0.153 CLER
360
- * **DB8 Regression:** MAE 8.77°
361
- * **Silent Speech Production:** 33.54% WER
362
- * **Silent Speech Recognition:** 33.95% WER
363
 
364
- TinyMyo matches or exceeds state-of-the-art performance, while being smaller and more efficient than all prior EMG foundation models.
 
 
 
365
 
366
  ---
367
 
368
- # πŸ› οΈ Code & Usage
 
 
 
 
369
 
370
- To fine-tune TinyMyo on downstream tasks, follow the examples in the
371
- **[BioFoundation repository](https://github.com/pulp-bio/BioFoundation)**.
372
 
373
- ```bash
374
- python -u run_train.py +experiment=TinyMyo_finetune \
375
- pretrained_safetensors_path=/path/to/model.safetensors
376
- ```
377
-
378
- Environment variables:
379
-
380
- * `DATA_PATH` → dataset path
381
- * `CHECKPOINT_DIR` → checkpoint to load
382
 
383
  ---
384
 
385
- ## 🔗 Resources
386
-
387
- - **Code:** https://github.com/pulp-bio/BioFoundation
 
388
 
389
  ---
390
 
391
- # 📜 Citation
392
-
393
- Please cite TinyMyo using:
394
 
395
  ```bibtex
396
- @misc{fasulo2025tinymyotinyfoundationmodel,
397
- title={TinyMyo: a Tiny Foundation Model for Flexible EMG Signal Processing at the Edge},
398
  author={Matteo Fasulo and Giusy Spacone and Thorir Mar Ingolfsson and Yawei Li and Luca Benini and Andrea Cossettini},
399
- year={2025},
400
  eprint={2512.15729},
401
  archivePrefix={arXiv},
402
  primaryClass={eess.SP},
403
- url={https://arxiv.org/abs/2512.15729},
404
  }
405
- ```
406
-
407
- ---
408
-
409
- # 🧭 Contact & Support
410
-
411
- * Questions or issues?
412
- Open an issue on the **BioFoundation GitHub repository**.
 
103
 
104
  <div align="center">
105
  <img src="https://raw.githubusercontent.com/MatteoFasulo/BioFoundation/refs/heads/TinyMyo/docs/model/logo/TinyMyo_logo.png" alt="TinyMyo Logo" width="400" />
106
+ <h1>TinyMyo: a Tiny Foundation Model for Flexible EMG Signal Processing at the Edge</h1>
107
+ </div>
108
+ <p align="center">
109
+ <a href="https://github.com/pulp-bio/BioFoundation"><img src ="https://img.shields.io/github/stars/pulp-bio/BioFoundation?color=ccf" alt="Github"></a>
110
+ <a href="https://creativecommons.org/licenses/by-nd/4.0/"><img src="https://img.shields.io/badge/License-CC_BY--ND_4.0-lightgrey.svg" alt="License"></a>
111
+ <a href="https://arxiv.org/abs/2512.15729"><img src="https://img.shields.io/badge/arXiv-2512.15729-b31b1b.svg" alt="Paper"></a>
112
  </p>
113
 
114
+ **TinyMyo** is a **3.6M-parameter** Transformer foundation model for surface EMG (sEMG), optimized for ultra-low-power edge deployment (GAP9 MCU). It demonstrates state-of-the-art performance across gesture classification, kinematic regression, and speech synthesis.
 
115
 
116
  ---
117
 
118
+ ## 🚀 Quick Start
119
 
120
+ TinyMyo is built as a specialized model within the [BioFoundation](https://github.com/pulp-bio/BioFoundation) framework.
121
 
122
+ ### 1. Requirements
123
+ - **Preprocessing:** Dependencies for data scripts are in `scripts/requirements.txt`.
124
+ - **BioFoundation:** Full framework requirements for training/inference are in the [GitHub repository](https://github.com/pulp-bio/BioFoundation/blob/main/requirements.txt).
 
 
 
 
125
 
126
+ ### 2. Preprocessing
127
+ Process raw datasets into HDF5 format:
128
+ ```bash
129
+ python scripts/db5.py --data_dir $DATA_PATH/raw/ --save_dir $DATA_PATH/h5/ --seq_len 200 --stride 50
130
+ ```
131
+ *See [scripts/README.md](scripts/README.md) for all dataset commands.*
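The scripts write each split as an HDF5 group with `data` and `label` datasets (see, e.g., the `{"data": [], "label": []}` structure in `scripts/db6.py`); the exact layout may vary per dataset script. A round-trip sketch with illustrative shapes:

```python
import os
import tempfile

import h5py
import numpy as np

# Write a toy split the way the preparation scripts do, then load it back.
path = os.path.join(tempfile.mkdtemp(), "demo.h5")
with h5py.File(path, "w") as f:
    g = f.create_group("train")
    g.create_dataset("data", data=np.zeros((4, 200, 16), dtype=np.float32))
    g.create_dataset("label", data=np.zeros(4, dtype=np.int64))

with h5py.File(path, "r") as f:
    X = f["train/data"][:]   # (N, seq_len, n_ch)
    y = f["train/label"][:]  # (N,)
```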
132
 
133
+ ### 3. Fine-tuning
134
+ ```bash
135
+ python run_train.py +experiment=TinyMyo_finetune pretrained_safetensors_path=/path/to/base.safetensors
136
+ ```
137
 
138
  ---
139
 
140
+ ## 🧠 Architecture & Pretraining
141
+ - **Core:** 8-layer Transformer encoder (192-dim embeddings, 3 heads).
142
+ - **Tokenization:** Channel-independent patching (20 samples/patch) with RoPE.
143
+ - **Data:** Pretrained on >480 GB of EMG (NinaPro DB6/7, EMG2Pose).
144
+ - **Specs:** 3.6M parameters, 4.0 GFLOPs.
145
 
146
+ ## 🎯 Benchmarks
 
147
 
148
+ | Task | Dataset | Metric | TinyMyo |
149
+ | :--- | :--- | :--- | :--- |
150
+ | **Gesture** | NinaPro DB5 | Accuracy | **89.41%** |
151
+ | **Gesture** | EPN-612 | Accuracy | **96.74%** |
152
+ | **Gesture** | UCI EMG | Accuracy | **97.56%** |
153
+ | **Regression**| NinaPro DB8 | MAE | **8.77Β°** |
154
+ | **Speech** | Gaddy (Speech Synthesis) | WER | **33.54%** |
155
+ | **Speech** | Gaddy (Speech Recognition) | WER | **33.95%** |
 
156
 
157
  ---
158
 
159
+ ## ⚡ Edge Performance (GAP9 MCU)
160
+ - **Inference:** 0.785 s
161
+ - **Energy:** 44.91 mJ
162
+ - **Power:** 57.18 mW
163
 
164
  ---
165
 
166
+ ## 📜 License & Citation
167
+ Weights are licensed under **CC BY-ND 4.0**. See [LICENSE](LICENSE) for details.
 
168
 
169
  ```bibtex
170
+ @misc{fasulo2026tinymyotinyfoundationmodel,
171
+ title={TinyMyo: a Tiny Foundation Model for Flexible EMG Signal Processing at the Edge},
172
  author={Matteo Fasulo and Giusy Spacone and Thorir Mar Ingolfsson and Yawei Li and Luca Benini and Andrea Cossettini},
173
+ year={2026},
174
  eprint={2512.15729},
175
  archivePrefix={arXiv},
176
  primaryClass={eess.SP},
177
+ url={https://arxiv.org/abs/2512.15729},
178
  }
179
+ ```
 
scripts/README.md CHANGED
@@ -1,137 +1,38 @@
1
- # Dataset Preparation Commands
2
 
3
- ## Overview
4
 
5
- This document provides the commands to prepare various EMG datasets for pretraining and downstream tasks. Each dataset preparation script takes in raw data, processes it into overlapping windows, and saves the processed data in HDF5 format for efficient loading during model training.
 
 
 
 
 
 
6
 
7
- Remember to add the flag `--download_data` if the dataset is not downloaded yet.
8
-
9
- Substitute the `$DATA_PATH` environment variable with your path for saving the dataset.
10
-
11
- The `seq_len` parameter in the scripts corresponds to the window size in samples, and the `stride` parameter corresponds to the step size between windows in samples. The sampling rate for the pretraining datasets is 2 kHz, while for the downstream datasets it is either 200 Hz or 2 kHz depending on the dataset.
12
-
13
- The required libraries for running the scripts are located inside the `requirements.txt` file.
14
 
15
  ## Pretraining Datasets
 
16
 
17
- For the pretraining datasets, we use a window size of 0.5 seconds with a 50% overlap at 2 kHz sampling rate:
18
-
19
- ### emg2pose (0.5 sec, 50% overlap)
20
-
21
- Note: due to the large size of emg2pose dataset, the `--download_data` flag is not available for this dataset.
22
-
23
- ```bash
24
- python scripts/emg2pose.py \
25
- --data_dir $DATA_PATH/datasets/emg2pose_data/ \
26
- --save_dir $DATA_PATH/datasets/emg2pose_data/h5/ \
27
- --seq_len 1000 \
28
- --stride 500
29
- ```
30
-
31
- ### Ninapro DB6 (0.5 sec, 50% overlap)
32
-
33
- ```bash
34
- python scripts/db6.py \
35
- --data_dir $DATA_PATH/datasets/ninapro/DB6/ \
36
- --save_dir $DATA_PATH/datasets/ninapro/DB6/h5/ \
37
- --seq_len 1000 \
38
- --stride 500
39
- ```
40
-
41
- ### Ninapro DB7 (0.5 sec, 50% overlap)
42
-
43
- ```bash
44
- python scripts/db7.py \
45
- --data_dir $DATA_PATH/datasets/ninapro/DB7/ \
46
- --save_dir $DATA_PATH/datasets/ninapro/DB7/h5/ \
47
- --seq_len 1000 \
48
- --stride 500
49
- ```
50
 
51
  ---
52
 
53
  ## Downstream Datasets
54
 
55
- For the downstream tasks, gesture classification is performed on NinaPro DB5, EMG-EPN612, and UCI EMG datasets (200 Hz) while regression is performed on NinaPro DB8 (2 kHz).
56
-
57
- ### Ninapro DB5 (1 sec, 25% overlap)
58
-
59
- ```bash
60
- python scripts/db5.py \
61
- --data_dir $DATA_PATH/datasets/ninapro/DB5/ \
62
- --save_dir $DATA_PATH/datasets/ninapro/DB5/h5/ \
63
- --seq_len 200 \
64
- --stride 50
65
- ```
66
-
67
- ### Ninapro DB5 (5 sec, 25% overlap)
68
-
69
- ```bash
70
- python scripts/db5.py \
71
- --data_dir $DATA_PATH/datasets/ninapro/DB5/ \
72
- --save_dir $DATA_PATH/datasets/ninapro/DB5/h5/ \
73
- --seq_len 1000 \
74
- --stride 250
75
- ```
76
-
77
- ### EMG-EPN612 (1 sec, no overlap)
78
-
79
- ```bash
80
- python scripts/epn.py \
81
- --data_dir $DATA_PATH/datasets/EPN612/ \
82
- --source_training $DATA_PATH/datasets/EPN612/trainingJSON/ \
83
- --source_testing $DATA_PATH/datasets/EPN612/testingJSON/ \
84
- --dest_dir $DATA_PATH/datasets/EPN612/h5/ \
85
- --seq_len 200
86
- ```
87
-
88
- ### EMG-EPN612 (5 sec, no overlap)
89
-
90
- ```bash
91
- python scripts/epn.py \
92
- --data_dir $DATA_PATH/datasets/EPN612/ \
93
- --source_training $DATA_PATH/datasets/EPN612/trainingJSON/ \
94
- --source_testing $DATA_PATH/datasets/EPN612/testingJSON/ \
95
- --dest_dir $DATA_PATH/datasets/EPN612/h5/ \
96
- --seq_len 1000
97
- ```
98
-
99
- ### UCI EMG (1 sec, 25% overlap)
100
-
101
- ```bash
102
- python scripts/uci.py \
103
- --data_dir $DATA_PATH/datasets/UCI_EMG/EMG_data_for_gestures-master/ \
104
- --save_dir $DATA_PATH/datasets/UCI_EMG/EMG_data_for_gestures-master/h5/ \
105
- --seq_len 200 \
106
- --stride 50
107
- ```
108
-
109
- ### UCI EMG (5 sec, 25% overlap)
110
-
111
- ```bash
112
- python scripts/uci.py \
113
- --data_dir $DATA_PATH/datasets/UCI_EMG/EMG_data_for_gestures-master/ \
114
- --save_dir $DATA_PATH/datasets/UCI_EMG/EMG_data_for_gestures-master/h5/ \
115
- --seq_len 1000 \
116
- --stride 250
117
- ```
118
-
119
- ### Ninapro DB8 (100 ms, no overlap)
120
-
121
- ```bash
122
- python scripts/db8.py \
123
- --data_dir $DATA_PATH/datasets/ninapro/DB8/ \
124
- --save_dir $DATA_PATH/datasets/ninapro/DB8/h5/ \
125
- --seq_len 200 \
126
- --stride 200
127
- ```
128
-
129
- ### Ninapro DB8 (500 ms, no overlap)
130
 
131
- ```bash
132
- python scripts/db8.py \
133
- --data_dir $DATA_PATH/datasets/ninapro/DB8/ \
134
- --save_dir $DATA_PATH/datasets/ninapro/DB8/h5/ \
135
- --seq_len 1000 \
136
- --stride 1000
137
- ```
 
1
+ # Dataset Preparation
2
 
3
+ This guide provides commands to process raw EMG data into HDF5 format using sliding windows.
4
 
5
+ ### Usage
6
+ - **Dependencies:** Install requirements specific to these scripts via `pip install -r scripts/requirements.txt`. Framework requirements for TinyMyo are in the [BioFoundation repository](https://github.com/pulp-bio/BioFoundation).
7
+ - Use `--download_data` if raw data is missing.
8
+ - Replace `$DATA_PATH` with your local storage path.
9
+ - `seq_len`: Window size (samples).
10
+ - `stride`: Step size (samples).
11
+ - Pretraining scripts use **2 kHz** sampling. Downstream scripts use **200 Hz** or **2 kHz**.
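Given `seq_len` and `stride`, the number of windows per recording follows directly. A small helper (illustrative only; boundary handling in the individual scripts may differ slightly):

```python
def n_windows(T: int, seq_len: int, stride: int) -> int:
    """Count of full sliding windows in a recording of T samples."""
    return 0 if T < seq_len else (T - seq_len) // stride + 1

# 0.5 s windows with 50% overlap at 2 kHz over a 10 s recording:
n_windows(20_000, seq_len=1000, stride=500)  # -> 39
```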
12
 
13
+ ---
14
 
15
  ## Pretraining Datasets
16
+ (0.5s windows, 50% overlap @ 2 kHz)
17
 
18
+ | Dataset | Size (GB) | Seq Len | Stride | Command |
19
+ | :--- | :--- | :--- | :--- | :--- |
20
+ | **EMG2Pose** | 431 | 1000 (0.5s) | 500 | `python scripts/emg2pose.py --data_dir $DATA_PATH/emg2pose_data/ --save_dir $DATA_PATH/emg2pose_data/h5/ --seq_len 1000 --stride 500` |
21
+ | **NinaPro DB6** | ~20 | 1000 (0.5s) | 500 | `python scripts/db6.py --data_dir $DATA_PATH/ninapro/DB6/ --save_dir $DATA_PATH/ninapro/DB6/h5/ --seq_len 1000 --stride 500` |
22
+ | **NinaPro DB7** | ~31 | 1000 (0.5s) | 500 | `python scripts/db7.py --data_dir $DATA_PATH/ninapro/DB7/ --save_dir $DATA_PATH/ninapro/DB7/h5/ --seq_len 1000 --stride 500` |
 
23
 
24
  ---
25
 
26
  ## Downstream Datasets
27
 
28
+ | Dataset | Task | Seq Len | Stride | Command |
29
+ | :--- | :--- | :--- | :--- | :--- |
30
+ | **NinaPro DB5** | Gesture | 200 (1s) | 50 | `python scripts/db5.py --data_dir $DATA_PATH/ninapro/DB5/ --save_dir $DATA_PATH/ninapro/DB5/h5/ --seq_len 200 --stride 50` |
31
+ | **NinaPro DB5** | Gesture | 1000 (5s) | 250 | `python scripts/db5.py --data_dir $DATA_PATH/ninapro/DB5/ --save_dir $DATA_PATH/ninapro/DB5/h5/ --seq_len 1000 --stride 250` |
32
+ | **EMG-EPN612** | Gesture | 200 (1s) | N/A | `python scripts/epn.py --data_dir $DATA_PATH/EPN612/ --source_training $DATA_PATH/EPN612/trainingJSON/ --source_testing $DATA_PATH/EPN612/testingJSON/ --dest_dir $DATA_PATH/EPN612/h5/ --seq_len 200` |
33
+ | **EMG-EPN612** | Gesture | 1000 (5s) | N/A | `python scripts/epn.py --data_dir $DATA_PATH/EPN612/ --source_training $DATA_PATH/EPN612/trainingJSON/ --source_testing $DATA_PATH/EPN612/testingJSON/ --dest_dir $DATA_PATH/EPN612/h5/ --seq_len 1000` |
34
+ | **UCI EMG** | Gesture | 200 (1s) | 50 | `python scripts/uci.py --data_dir $DATA_PATH/UCI_EMG/EMG_data_for_gestures-master/ --save_dir $DATA_PATH/UCI_EMG/EMG_data_for_gestures-master/h5/ --seq_len 200 --stride 50` |
35
+ | **UCI EMG** | Gesture | 1000 (5s) | 250 | `python scripts/uci.py --data_dir $DATA_PATH/UCI_EMG/EMG_data_for_gestures-master/ --save_dir $DATA_PATH/UCI_EMG/EMG_data_for_gestures-master/h5/ --seq_len 1000 --stride 250` |
36
+ | **NinaPro DB8** | Regression | 200 (0.1s) | 200 | `python scripts/db8.py --data_dir $DATA_PATH/ninapro/DB8/ --save_dir $DATA_PATH/ninapro/DB8/h5/ --seq_len 200 --stride 200` |
37
+ | **NinaPro DB8** | Regression | 1000 (0.5s) | 1000 | `python scripts/db8.py --data_dir $DATA_PATH/ninapro/DB8/ --save_dir $DATA_PATH/ninapro/DB8/h5/ --seq_len 1000 --stride 1000` |
38
 
scripts/db5.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import sys
 
3
 
4
  import h5py
5
  import numpy as np
@@ -7,30 +8,77 @@ import scipy.io
7
  import scipy.signal as signal
8
  from scipy.signal import iirnotch
9
 
10
- sequence_to_seconds = lambda seq_len, fs: seq_len / fs
 
11
 
 
 
 
12
 
13
- # ==== Data augmentation functions ====
14
- def random_amplitude_scale(sig, scale_range=(0.9, 1.1)):
 
15
  scale = np.random.uniform(*scale_range)
16
  return sig * scale
17
 
18
 
19
- def random_time_jitter(sig, jitter_ratio=0.01):
20
  T, D = sig.shape
21
  std_ch = np.std(sig, axis=0)
22
  noise = np.random.randn(T, D) * (jitter_ratio * std_ch)
23
  return sig + noise
24
 
25
 
26
- def random_channel_dropout(sig, dropout_prob=0.05):
27
  T, D = sig.shape
28
  mask = np.random.rand(D) < dropout_prob
29
  sig[:, mask] = 0.0
30
  return sig
31
 
32
 
33
- def augment_one_sample(seg):
34
  out = seg.copy()
35
  out = random_amplitude_scale(out, (0.9, 1.1))
36
  out = random_time_jitter(out, 0.01)
@@ -38,7 +86,20 @@ def augment_one_sample(seg):
38
  return out
39
 
40
 
41
- def augment_train_data(data, labels, factor=3):
42
  if factor <= 0 or data.shape[0] == 0:
43
  return data, labels
44
  aug_segs = [data]
@@ -55,8 +116,19 @@ def augment_train_data(data, labels, factor=3):
55
  return new_data, new_labels
56
 
57
 
58
- # ==== Filter functions (operate at original fs=200) ====
59
- def notch_filter(data, notch_freq=50.0, Q=30.0, fs=200.0):
60
  b, a = iirnotch(notch_freq, Q, fs)
61
  out = np.zeros_like(data)
62
  for ch in range(data.shape[1]):
@@ -64,7 +136,25 @@ def notch_filter(data, notch_freq=50.0, Q=30.0, fs=200.0):
64
  return out
65
 
66
 
67
- def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=200.0, order=4):
68
  nyq = 0.5 * fs
69
  low = lowcut / nyq
70
  high = highcut / nyq
@@ -75,8 +165,28 @@ def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=200.0, order=4):
75
  return out
76
 
77
 
78
- # ==== Window segmentation ====
79
- def process_emg_features(emg, label, rerep, window_size=1024, stride=512):
80
  segs, lbls, reps = [], [], []
81
  N = len(label)
82
  for start in range(0, N, stride):
@@ -94,7 +204,6 @@ def process_emg_features(emg, label, rerep, window_size=1024, stride=512):
94
  return np.array(segs), np.array(lbls), np.array(reps)
95
 
96
 
97
- # ==== Main pipeline ====
98
  def main():
99
  import argparse
100
 
 
1
  import os
2
  import sys
3
+ from typing import Tuple, List, Optional, Union, Dict, Any, Callable
4
 
5
  import h5py
6
  import numpy as np
 
8
  import scipy.signal as signal
9
  from scipy.signal import iirnotch
10
 
11
+ def sequence_to_seconds(seq_len: int, fs: float) -> float:
12
+ """Converts a sequence length in samples to time in seconds.
13
 
14
+ Args:
15
+ seq_len (int): The number of samples in the sequence.
16
+ fs (float): The sampling frequency in Hz.
17
 
18
+ Returns:
19
+ float: The duration of the sequence in seconds.
20
+ """
21
+ return seq_len / fs
22
+
23
+
24
+ def random_amplitude_scale(sig: np.ndarray, scale_range: Tuple[float, float] = (0.9, 1.1)) -> np.ndarray:
25
+ """Applies random amplitude scaling to the input signal.
26
+
27
+ Args:
28
+ sig (np.ndarray): The input signal array of shape (T, D).
29
+ scale_range (Tuple[float, float], optional): The range [min, max] for the scaling factor.
30
+ Defaults to (0.9, 1.1).
31
+
32
+ Returns:
33
+ np.ndarray: The scaled signal array.
34
+ """
35
  scale = np.random.uniform(*scale_range)
36
  return sig * scale
37
 
38
 
39
+ def random_time_jitter(sig: np.ndarray, jitter_ratio: float = 0.01) -> np.ndarray:
40
+ """Adds random Gaussian noise (jitter) to the input signal.
41
+
42
+ Args:
43
+ sig (np.ndarray): The input signal array of shape (T, D).
44
+ jitter_ratio (float, optional): The ratio to scale the noise relative to
45
+ each channel's standard deviation. Defaults to 0.01.
46
+
47
+ Returns:
48
+ np.ndarray: The signal with added jitter.
49
+ """
50
  T, D = sig.shape
51
  std_ch = np.std(sig, axis=0)
52
  noise = np.random.randn(T, D) * (jitter_ratio * std_ch)
53
  return sig + noise
54
 
55
 
56
+ def random_channel_dropout(sig: np.ndarray, dropout_prob: float = 0.05) -> np.ndarray:
57
+ """Randomly zeros out channels in the signal based on a probability.
58
+
59
+ Args:
60
+ sig (np.ndarray): The input signal array of shape (T, D).
61
+ dropout_prob (float, optional): Probability of dropping each channel.
62
+ Defaults to 0.05.
63
+
64
+ Returns:
65
+ np.ndarray: The signal with dropped channels.
66
+ """
67
  T, D = sig.shape
68
  mask = np.random.rand(D) < dropout_prob
69
  sig[:, mask] = 0.0
70
  return sig
71
 
72
 
73
+ def augment_one_sample(seg: np.ndarray) -> np.ndarray:
74
+ """Applies a sequence of random augmentations to a single signal segment.
75
+
76
+ Args:
77
+ seg (np.ndarray): Single signal segment of shape (window_size, n_ch).
78
+
79
+ Returns:
80
+ np.ndarray: The augmented signal segment.
81
+ """
82
  out = seg.copy()
83
  out = random_amplitude_scale(out, (0.9, 1.1))
84
  out = random_time_jitter(out, 0.01)
 
86
  return out
87
 
88
 
89
+ def augment_train_data(data: np.ndarray, labels: np.ndarray, factor: int = 3) -> Tuple[np.ndarray, np.ndarray]:
90
+ """Augments the training dataset by creating multiple versions of each sample.
91
+
92
+ Args:
93
+ data (np.ndarray): The input dataset of shape (N, window_size, n_ch).
94
+ labels (np.ndarray): The corresponding labels of shape (N,).
95
+ factor (int, optional): The number of augmented versions to create for each sample.
96
+ Defaults to 3.
97
+
98
+ Returns:
99
+ Tuple[np.ndarray, np.ndarray]: A tuple containing:
100
+ - The augmented dataset.
101
+ - The augmented labels.
102
+ """
103
  if factor <= 0 or data.shape[0] == 0:
104
  return data, labels
105
  aug_segs = [data]
 
116
  return new_data, new_labels
117
 
118
 
119
+ def notch_filter(data: np.ndarray, notch_freq: float = 50.0, Q: float = 30.0, fs: float = 200.0) -> np.ndarray:
120
+ """Applies a notch filter to remove power line interference.
121
+
122
+ Args:
123
+ data (np.ndarray): The input signal array of shape (T, D).
124
+ notch_freq (float, optional): The frequency to be removed (e.g., 50 Hz or 60 Hz).
125
+ Defaults to 50.0.
126
+ Q (float, optional): The quality factor. Defaults to 30.0.
127
+ fs (float, optional): The sampling frequency of the signal. Defaults to 200.0.
128
+
129
+ Returns:
130
+ np.ndarray: The filtered signal array.
131
+ """
132
  b, a = iirnotch(notch_freq, Q, fs)
133
  out = np.zeros_like(data)
134
  for ch in range(data.shape[1]):
 
136
  return out
137
 
138
 
139
+ def bandpass_filter_emg(
140
+ emg: np.ndarray,
141
+ lowcut: float = 20.0,
142
+ highcut: float = 90.0,
143
+ fs: float = 200.0,
144
+ order: int = 4
145
+ ) -> np.ndarray:
146
+ """Applies a Butterworth bandpass filter to the EMG signal.
147
+
148
+ Args:
149
+ emg (np.ndarray): The input signal array of shape (T, D).
150
+ lowcut (float, optional): Lower bound of the passband in Hz. Defaults to 20.0.
151
+ highcut (float, optional): Upper bound of the passband in Hz. Defaults to 90.0.
152
+ fs (float, optional): The sampling frequency of the signal. Defaults to 200.0.
153
+ order (int, optional): The order of the filter. Defaults to 4.
154
+
155
+ Returns:
156
+ np.ndarray: The bandpass filtered signal array.
157
+ """
158
  nyq = 0.5 * fs
159
  low = lowcut / nyq
160
  high = highcut / nyq
 
165
  return out
166
 
167
 
168
+ def process_emg_features(
169
+ emg: np.ndarray,
170
+ label: np.ndarray,
171
+ rerep: np.ndarray,
172
+ window_size: int = 1024,
173
+ stride: int = 512
174
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
175
+ """Segments raw EMG signals into overlapping windows.
176
+
177
+ Args:
178
+ emg (np.ndarray): Raw EMG data of shape (T, n_ch).
179
+ label (np.ndarray): Gesture labels of shape (T,).
180
+ rerep (np.ndarray): Repetition indices of shape (T,).
181
+ window_size (int, optional): Number of samples per window. Defaults to 1024.
182
+ stride (int, optional): Number of samples to shift between windows. Defaults to 512.
183
+
184
+ Returns:
185
+ Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
186
+ - windowed segments (N, window_size, n_ch).
187
+ - labels for each window (N,).
188
+ - repetition indices for each window (N,).
189
+ """
190
  segs, lbls, reps = [], [], []
191
  N = len(label)
192
  for start in range(0, N, stride):
 
204
  return np.array(segs), np.array(lbls), np.array(reps)
205
 
206
 
 
207
  def main():
208
  import argparse
209
 
scripts/db6.py CHANGED
@@ -5,12 +5,13 @@ import h5py
5
  import numpy as np
6
  import scipy.io
7
  import scipy.signal as signal
 
8
  from scipy.signal import iirnotch
 
9
 
10
  sequence_to_seconds = lambda seq_len, fs: seq_len / fs
11
 
12
 
13
- # ─────────────── Filtering ──────────────────
14
  def notch_filter(data, notch_freq=50.0, Q=30.0, fs=2000.0):
15
  """Notch-filter every channel independently."""
16
  b, a = iirnotch(notch_freq, Q, fs)
@@ -29,7 +30,6 @@ def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=2000.0, order=4):
29
  return out
30
 
31
 
32
- # ─────────────── Sliding window ──────────────
33
  def sliding_window_segment(emg, label, rerepetition, window_size, stride):
34
  """
35
  Segment EMG with a sliding window.
@@ -49,7 +49,64 @@ def sliding_window_segment(emg, label, rerepetition, window_size, stride):
49
  return np.array(segments), np.array(labels), np.array(reps)
50
 
51
 
52
- # ─────────────── Main pipeline ───────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def main():
54
  import argparse
55
 
@@ -65,6 +122,18 @@ def main():
65
  type=int,
66
  help="Step size between windows in samples for segmentation.",
67
  )
 
 
 
 
 
 
 
 
 
 
 
 
68
  args = args.parse_args()
69
  data_dir = args.data_dir # input folder with .mat files
70
  save_dir = args.save_dir # output folder for .h5 files
@@ -105,62 +174,30 @@ def main():
105
  "test": {"data": [], "label": []},
106
  }
107
 
108
- # iterate subjects
109
- for subj in sorted(os.listdir(data_dir)):
110
- subj_path = os.path.join(data_dir, subj)
111
- if not os.path.isdir(subj_path):
112
- continue
113
- print(f"Processing subject {subj} ...")
114
-
115
- subj_seg, subj_lbl, subj_rep = [], [], []
 
 
 
 
 
 
 
 
 
 
116
 
117
- # iterate .mat files
118
- for mat_file in sorted(os.listdir(subj_path)):
119
- if not mat_file.endswith(".mat"):
 
120
  continue
121
- mat_path = os.path.join(subj_path, mat_file)
122
- mat = scipy.io.loadmat(mat_path)
123
-
124
- emg = mat["emg"] # (N, 16)
125
- label = mat["restimulus"].ravel()
126
- rerep = mat["rerepetition"].ravel()
127
-
128
- # drop empty channels (index 8, 9 β†’ 0-based)
129
- emg = np.delete(emg, [8, 9], axis=1) # now (N, 14)
130
-
131
- # filtering
132
- emg = bandpass_filter_emg(emg, 20, 450, fs=fs)
133
- emg = notch_filter(emg, 50, 30, fs=fs)
134
-
135
- # z-score per channel
136
- mu = emg.mean(axis=0)
137
- sd = emg.std(axis=0, ddof=1)
138
- sd[sd == 0] = 1.0
139
- emg = (emg - mu) / sd
140
-
141
- # windowing
142
- seg, lbl, rep = sliding_window_segment(
143
- emg, label, rerep, window_size, stride
144
- )
145
- subj_seg.append(seg)
146
- subj_lbl.append(lbl)
147
- subj_rep.append(rep)
148
-
149
- if not subj_seg:
150
- continue
151
-
152
- seg = np.concatenate(subj_seg, axis=0) # (M, win, 14)
153
- lbl = np.concatenate(subj_lbl)
154
- rep = np.concatenate(subj_rep)
155
-
156
- # split by repetition id
157
- for split_name, mask in (
158
- ("train", np.isin(rep, train_reps)),
159
- ("val", np.isin(rep, val_reps)),
160
- ("test", np.isin(rep, test_reps)),
161
- ):
162
- X = seg[mask].transpose(0, 2, 1) # (N, 14, 1024)
163
- y = lbl[mask]
164
  splits[split_name]["data"].append(X)
165
  splits[split_name]["label"].append(y)
166
 
@@ -177,9 +214,20 @@ def main():
177
  else np.empty((0,), dtype=int)
178
  )
179
 
180
- with h5py.File(os.path.join(save_dir, f"{split}.h5"), "w") as f:
181
- f.create_dataset("data", data=X.astype(np.float32))
182
- f.create_dataset("label", data=y.astype(np.int64))
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  uniq, cnt = np.unique(y, return_counts=True)
185
  print(f"\n{split.upper()} β†’ X={X.shape}, label distribution:")
 
  import numpy as np
  import scipy.io
  import scipy.signal as signal
+ from joblib import Parallel, delayed
  from scipy.signal import iirnotch
+ from tqdm import tqdm

  sequence_to_seconds = lambda seq_len, fs: seq_len / fs

  def notch_filter(data, notch_freq=50.0, Q=30.0, fs=2000.0):
  """Notch-filter every channel independently."""
  b, a = iirnotch(notch_freq, Q, fs)
  return out

  def sliding_window_segment(emg, label, rerepetition, window_size, stride):
  """
  Segment EMG with a sliding window.
  return np.array(segments), np.array(labels), np.array(reps)

+ def process_subject(
+ subj_path,
+ window_size,
+ stride,
+ fs,
+ train_reps,
+ val_reps,
+ test_reps,
+ ):
+ subj_seg, subj_lbl, subj_rep = [], [], []
+
+ for mat_file in sorted(os.listdir(subj_path)):
+ if not mat_file.endswith(".mat"):
+ continue
+ mat_path = os.path.join(subj_path, mat_file)
+ mat = scipy.io.loadmat(mat_path)
+
+ emg = mat["emg"] # (N, 16)
+ label = mat["restimulus"].ravel()
+ rerep = mat["rerepetition"].ravel()
+
+ emg = np.delete(emg, [8, 9], axis=1) # now (N, 14)
+ emg = bandpass_filter_emg(emg, 20, 450, fs=fs)
+ emg = notch_filter(emg, 50, 30, fs=fs)
+
+ mu = emg.mean(axis=0)
+ sd = emg.std(axis=0, ddof=1)
+ sd[sd == 0] = 1.0
+ emg = (emg - mu) / sd
+
+ seg, lbl, rep = sliding_window_segment(emg, label, rerep, window_size, stride)
+ subj_seg.append(seg)
+ subj_lbl.append(lbl)
+ subj_rep.append(rep)
+
+ if not subj_seg:
+ return {
+ "train": (np.empty((0, 14, window_size), dtype=np.float32), np.empty((0,), dtype=np.int64)),
+ "val": (np.empty((0, 14, window_size), dtype=np.float32), np.empty((0,), dtype=np.int64)),
+ "test": (np.empty((0, 14, window_size), dtype=np.float32), np.empty((0,), dtype=np.int64)),
+ }
+
+ seg = np.concatenate(subj_seg, axis=0)
+ lbl = np.concatenate(subj_lbl)
+ rep = np.concatenate(subj_rep)
+
+ out = {}
+ for split_name, mask in (
+ ("train", np.isin(rep, train_reps)),
+ ("val", np.isin(rep, val_reps)),
+ ("test", np.isin(rep, test_reps)),
+ ):
+ X = seg[mask].transpose(0, 2, 1).astype(np.float32)
+ y = lbl[mask].astype(np.int64)
+ out[split_name] = (X, y)
+ return out

  def main():
  import argparse
  type=int,
  help="Step size between windows in samples for segmentation.",
  )
+ args.add_argument(
+ "--group_size",
+ type=int,
+ default=1000,
+ help="Number of samples per group in the output HDF5 file.",
+ )
+ args.add_argument(
+ "--n_jobs",
+ type=int,
+ default=-1,
+ help="Number of subjects to process in parallel. -1 means all cores.",
+ )
  args = args.parse_args()
  data_dir = args.data_dir # input folder with .mat files
  save_dir = args.save_dir # output folder for .h5 files
  "test": {"data": [], "label": []},
  }
+ subject_paths = [
+ os.path.join(data_dir, subj)
+ for subj in sorted(os.listdir(data_dir))
+ if os.path.isdir(os.path.join(data_dir, subj))
+ ]
+
+ subject_results = Parallel(n_jobs=args.n_jobs)(
+ delayed(process_subject)(
+ subj_path,
+ window_size,
+ stride,
+ fs,
+ train_reps,
+ val_reps,
+ test_reps,
+ )
+ for subj_path in tqdm(subject_paths, desc="Processing subjects")
+ )

+ for result in subject_results:
+ for split_name in ["train", "val", "test"]:
+ X, y = result[split_name]
+ if X.shape[0] == 0:
  continue
  splits[split_name]["data"].append(X)
  splits[split_name]["label"].append(y)
  else np.empty((0,), dtype=int)
  )
+ out_path = os.path.join(save_dir, f"{split}.h5")
+ if os.path.exists(out_path):
+ os.remove(out_path)
+ print(f"Removed existing file {out_path} to avoid overwrite issues.")
+
+ with h5py.File(out_path, "w") as h5f:
+ for group_idx, start in enumerate(range(0, X.shape[0], args.group_size)):
+ end = min(start + args.group_size, X.shape[0])
+ x_chunk = X[start:end].astype(np.float32)
+ y_chunk = y[start:end].astype(np.int64)
+
+ grp = h5f.create_group(f"data_group_{group_idx}")
+ grp.create_dataset("X", data=x_chunk)
+ grp.create_dataset("y", data=y_chunk)

  uniq, cnt = np.unique(y, return_counts=True)
  print(f"\n{split.upper()} β†’ X={X.shape}, label distribution:")
scripts/db7.py CHANGED
@@ -5,12 +5,13 @@ import h5py
  import numpy as np
  import scipy.io
  import scipy.signal as signal
  from scipy.signal import iirnotch

  sequence_to_seconds = lambda seq_len, fs: seq_len / fs

- # ─────────────── Filtering ──────────────────
  def notch_filter(data, notch_freq=50.0, Q=30.0, fs=2000.0):
  """Notch-filter every channel independently."""
  b, a = iirnotch(notch_freq, Q, fs)
@@ -29,7 +30,6 @@ def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=2000.0, order=4):
  return out

- # ─────────────── Sliding window ──────────────
  def sliding_window_segment(emg, label, rerepetition, window_size, stride):
  """
  Segment EMG with a sliding window.
@@ -49,7 +49,63 @@ def sliding_window_segment(emg, label, rerepetition, window_size, stride):
  return np.array(segments), np.array(labels), np.array(reps)

- # ─────────────── Main pipeline ───────────────
  def main():
  import argparse
@@ -65,6 +121,18 @@ def main():
  type=int,
  help="Step size between windows in samples for segmentation.",
  )
  args = args.parse_args()
  data_dir = args.data_dir # input folder with .mat files
  save_dir = args.save_dir # output folder for .h5 files
@@ -100,59 +168,30 @@ def main():
  "test": {"data": [], "label": []},
  }

- # iterate subjects
- for subj in sorted(os.listdir(data_dir)):
- subj_path = os.path.join(data_dir, subj)
- if not os.path.isdir(subj_path):
- continue
- print(f"Processing subject {subj} ...")
-
- subj_seg, subj_lbl, subj_rep = [], [], []

- # iterate .mat files
- for mat_file in sorted(os.listdir(subj_path)):
- if not mat_file.endswith(".mat"):
  continue
- mat_path = os.path.join(subj_path, mat_file)
- mat = scipy.io.loadmat(mat_path)
-
- emg = mat["emg"] # (N, 16)
- label = mat["restimulus"].ravel()
- rerep = mat["rerepetition"].ravel()
-
- # filtering
- emg = bandpass_filter_emg(emg, 20.0, 450.0, fs=fs)
- emg = notch_filter(emg, 50.0, 30.0, fs=fs)
-
- # z-score per channel
- mu = emg.mean(axis=0)
- sd = emg.std(axis=0, ddof=1)
- sd[sd == 0] = 1.0
- emg = (emg - mu) / sd
-
- # windowing
- seg, lbl, rep = sliding_window_segment(
- emg, label, rerep, window_size, stride
- )
- subj_seg.append(seg)
- subj_lbl.append(lbl)
- subj_rep.append(rep)
-
- if not subj_seg:
- continue
-
- seg = np.concatenate(subj_seg, axis=0) # (M, win, 14)
- lbl = np.concatenate(subj_lbl)
- rep = np.concatenate(subj_rep)
-
- # split by repetition id
- for split_name, mask in (
- ("train", np.isin(rep, train_reps)),
- ("val", np.isin(rep, val_reps)),
- ("test", np.isin(rep, test_reps)),
- ):
- X = seg[mask].transpose(0, 2, 1) # (N, 14, 1024)
- y = lbl[mask]
  splits[split_name]["data"].append(X)
  splits[split_name]["label"].append(y)
@@ -161,7 +200,7 @@ def main():
  X = (
  np.concatenate(splits[split]["data"], axis=0)
  if splits[split]["data"]
- else np.empty((0, 14, window_size))
  )
  y = (
  np.concatenate(splits[split]["label"], axis=0)
@@ -169,9 +208,20 @@ def main():
  else np.empty((0,), dtype=int)
  )

- with h5py.File(os.path.join(save_dir, f"{split}.h5"), "w") as f:
- f.create_dataset("data", data=X.astype(np.float32))
- f.create_dataset("label", data=y.astype(np.int64))

  uniq, cnt = np.unique(y, return_counts=True)
  print(f"\n{split.upper()} β†’ X={X.shape}, label distribution:")
  import numpy as np
  import scipy.io
  import scipy.signal as signal
+ from joblib import Parallel, delayed
  from scipy.signal import iirnotch
+ from tqdm import tqdm

  sequence_to_seconds = lambda seq_len, fs: seq_len / fs

  def notch_filter(data, notch_freq=50.0, Q=30.0, fs=2000.0):
  """Notch-filter every channel independently."""
  b, a = iirnotch(notch_freq, Q, fs)
  return out

  def sliding_window_segment(emg, label, rerepetition, window_size, stride):
  """
  Segment EMG with a sliding window.
  return np.array(segments), np.array(labels), np.array(reps)

+ def process_subject(
+ subj_path,
+ window_size,
+ stride,
+ fs,
+ train_reps,
+ val_reps,
+ test_reps,
+ ):
+ subj_seg, subj_lbl, subj_rep = [], [], []
+
+ for mat_file in sorted(os.listdir(subj_path)):
+ if not mat_file.endswith(".mat"):
+ continue
+ mat_path = os.path.join(subj_path, mat_file)
+ mat = scipy.io.loadmat(mat_path)
+
+ emg = mat["emg"] # (N, 16)
+ label = mat["restimulus"].ravel()
+ rerep = mat["rerepetition"].ravel()
+
+ emg = bandpass_filter_emg(emg, 20.0, 450.0, fs=fs)
+ emg = notch_filter(emg, 50.0, 30.0, fs=fs)
+
+ mu = emg.mean(axis=0)
+ sd = emg.std(axis=0, ddof=1)
+ sd[sd == 0] = 1.0
+ emg = (emg - mu) / sd
+
+ seg, lbl, rep = sliding_window_segment(emg, label, rerep, window_size, stride)
+ subj_seg.append(seg)
+ subj_lbl.append(lbl)
+ subj_rep.append(rep)
+
+ if not subj_seg:
+ return {
+ "train": (np.empty((0, 16, window_size), dtype=np.float32), np.empty((0,), dtype=np.int64)),
+ "val": (np.empty((0, 16, window_size), dtype=np.float32), np.empty((0,), dtype=np.int64)),
+ "test": (np.empty((0, 16, window_size), dtype=np.float32), np.empty((0,), dtype=np.int64)),
+ }
+
+ seg = np.concatenate(subj_seg, axis=0)
+ lbl = np.concatenate(subj_lbl)
+ rep = np.concatenate(subj_rep)
+
+ out = {}
+ for split_name, mask in (
+ ("train", np.isin(rep, train_reps)),
+ ("val", np.isin(rep, val_reps)),
+ ("test", np.isin(rep, test_reps)),
+ ):
+ X = seg[mask].transpose(0, 2, 1).astype(np.float32)
+ y = lbl[mask].astype(np.int64)
+ out[split_name] = (X, y)
+ return out

  def main():
  import argparse
  type=int,
  help="Step size between windows in samples for segmentation.",
  )
+ args.add_argument(
+ "--group_size",
+ type=int,
+ default=1000,
+ help="Number of samples per group in the output HDF5 file.",
+ )
+ args.add_argument(
+ "--n_jobs",
+ type=int,
+ default=-1,
+ help="Number of subjects to process in parallel. -1 means all cores.",
+ )
  args = args.parse_args()
  data_dir = args.data_dir # input folder with .mat files
  save_dir = args.save_dir # output folder for .h5 files
  "test": {"data": [], "label": []},
  }
+ subject_paths = [
+ os.path.join(data_dir, subj)
+ for subj in sorted(os.listdir(data_dir))
+ if os.path.isdir(os.path.join(data_dir, subj))
+ ]
+
+ subject_results = Parallel(n_jobs=args.n_jobs)(
+ delayed(process_subject)(
+ subj_path,
+ window_size,
+ stride,
+ fs,
+ train_reps,
+ val_reps,
+ test_reps,
+ )
+ for subj_path in tqdm(subject_paths, desc="Processing subjects")
+ )

+ for result in subject_results:
+ for split_name in ["train", "val", "test"]:
+ X, y = result[split_name]
+ if X.shape[0] == 0:
  continue
  splits[split_name]["data"].append(X)
  splits[split_name]["label"].append(y)
  X = (
  np.concatenate(splits[split]["data"], axis=0)
  if splits[split]["data"]
+ else np.empty((0, 16, window_size))
  )
  y = (
  np.concatenate(splits[split]["label"], axis=0)
  else np.empty((0,), dtype=int)
  )
+ out_path = os.path.join(save_dir, f"{split}.h5")
+ if os.path.exists(out_path):
+ os.remove(out_path)
+ print(f"Removed existing file {out_path} to avoid overwrite issues.")
+
+ with h5py.File(out_path, "w") as h5f:
+ for group_idx, start in enumerate(range(0, X.shape[0], args.group_size)):
+ end = min(start + args.group_size, X.shape[0])
+ x_chunk = X[start:end].astype(np.float32)
+ y_chunk = y[start:end].astype(np.int64)
+
+ grp = h5f.create_group(f"data_group_{group_idx}")
+ grp.create_dataset("X", data=x_chunk)
+ grp.create_dataset("y", data=y_chunk)

  uniq, cnt = np.unique(y, return_counts=True)
  print(f"\n{split.upper()} β†’ X={X.shape}, label distribution:")
scripts/db8.py CHANGED
@@ -1,5 +1,6 @@
  import os
  import sys

  import h5py
  import numpy as np
@@ -9,7 +10,18 @@ from joblib import Parallel, delayed
  from scipy.signal import iirnotch
  from tqdm import tqdm

- sequence_to_seconds = lambda seq_len, fs: seq_len / fs

  _MATRIX_DOF2DOA_TRANSPOSED = np.array(
  # https://www.frontiersin.org/articles/10.3389/fnins.2019.00891/full
@@ -42,9 +54,18 @@ _MATRIX_DOF2DOA_TRANSPOSED = np.array(
  MATRIX_DOF2DOA = _MATRIX_DOF2DOA_TRANSPOSED.T

- # ─────────────── Filtering ──────────────────
- def notch_filter(data, notch_freq=50.0, Q=30.0, fs=1111.0):
- """Notch-filter every channel independently."""
  b, a = iirnotch(notch_freq, Q, fs)
  out = np.zeros_like(data)
  for ch in range(data.shape[1]):
@@ -52,7 +73,25 @@ def notch_filter(data, notch_freq=50.0, Q=30.0, fs=1111.0):
  return out

- def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=2000.0, order=4):
  nyq = 0.5 * fs
  b, a = signal.butter(order, [lowcut / nyq, highcut / nyq], btype="bandpass")
  out = np.zeros_like(emg)
@@ -61,11 +100,24 @@ def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=2000.0, order=4):
  return out

- # ─────────────── Sliding window ──────────────
- def sliding_window_segment(emg, label, window_size, stride):
- """
- Segment EMG with a sliding window.
- Use the frame at the window centre as the segment label / repetition index.
  """
  segments, labels = [], []
  n_samples = len(label)
@@ -80,34 +132,49 @@ def sliding_window_segment(emg, label, window_size, stride):
  return np.array(segments), np.array(labels)

- # ─────────────── Main pipeline ───────────────
- def process_mat_file(mat_path, window_size, stride, fs):
- """
- Load one .mat file, filter out NaNs, filter & normalize EMG, map DoF→DoA,
- segment, and return (split, segs, labels).
  """
  mat = scipy.io.loadmat(mat_path)
  emg = mat["emg"] # (T, 16)
  label = mat["glove"] # (T, DoF)

- # 1) Drop timesteps with any NaNs in glove data
  valid = ~np.isnan(label).any(axis=1)
  emg = emg[valid]
  label = label[valid]

- # 3) Z-score per channel
  mu = emg.mean(axis=0)
  sd = emg.std(axis=0, ddof=1)
  sd[sd == 0] = 1.0
  emg = (emg - mu) / sd

- # 4) DoF β†’ DoA
  y_doa = (MATRIX_DOF2DOA @ label.T).T

- # 5) Windowing
  segs, labs = sliding_window_segment(emg, y_doa, window_size, stride)

- # 6) Determine split
  fname = os.path.basename(mat_path)
  if "_A1" in fname:
  split = "train"
  import os
  import sys
+ from typing import Tuple, List, Optional, Union, Dict, Any

  import h5py
  import numpy as np
  from scipy.signal import iirnotch
  from tqdm import tqdm

+ def sequence_to_seconds(seq_len: int, fs: float) -> float:
+ """Converts a sequence length in samples to time in seconds.
+
+ Args:
+ seq_len (int): The number of samples in the sequence.
+ fs (float): The sampling frequency in Hz.
+
+ Returns:
+ float: The duration of the sequence in seconds.
+ """
+ return seq_len / fs

  _MATRIX_DOF2DOA_TRANSPOSED = np.array(
  # https://www.frontiersin.org/articles/10.3389/fnins.2019.00891/full
  MATRIX_DOF2DOA = _MATRIX_DOF2DOA_TRANSPOSED.T

+ def notch_filter(data: np.ndarray, notch_freq: float = 50.0, Q: float = 30.0, fs: float = 1111.0) -> np.ndarray:
+ """Applies a notch filter to every channel of the input data independently.
+
+ Args:
+ data (np.ndarray): The input signal array of shape (T, D).
+ notch_freq (float, optional): The frequency to be removed. Defaults to 50.0.
+ Q (float, optional): The quality factor. Defaults to 30.0.
+ fs (float, optional): The sampling frequency in Hz. Defaults to 1111.0.
+
+ Returns:
+ np.ndarray: The filtered signal array.
+ """
  b, a = iirnotch(notch_freq, Q, fs)
  out = np.zeros_like(data)
  for ch in range(data.shape[1]):
  return out

+ def bandpass_filter_emg(
+ emg: np.ndarray,
+ lowcut: float = 20.0,
+ highcut: float = 90.0,
+ fs: float = 2000.0,
+ order: int = 4
+ ) -> np.ndarray:
+ """Applies a Butterworth bandpass filter to the EMG signal.
+
+ Args:
+ emg (np.ndarray): The input signal array of shape (T, D).
+ lowcut (float, optional): Lower bound of the passband in Hz. Defaults to 20.0.
+ highcut (float, optional): Upper bound of the passband in Hz. Defaults to 90.0.
+ fs (float, optional): The sampling frequency in Hz. Defaults to 2000.0.
+ order (int, optional): The order of the filter. Defaults to 4.
+
+ Returns:
+ np.ndarray: The filtered signal array.
+ """
  nyq = 0.5 * fs
  b, a = signal.butter(order, [lowcut / nyq, highcut / nyq], btype="bandpass")
  out = np.zeros_like(emg)
  return out

+ def sliding_window_segment(
+ emg: np.ndarray,
+ label: np.ndarray,
+ window_size: int,
+ stride: int
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Segments EMG and label data using a sliding window.
+
+ Args:
+ emg (np.ndarray): The raw EMG data of shape (T, n_ch).
+ label (np.ndarray): The corresponding labels/targets.
+ window_size (int): Number of samples per window.
+ stride (int): Number of samples to shift between windows.
+
+ Returns:
+ Tuple[np.ndarray, np.ndarray]: A tuple containing:
+ - segmented EMG tokens (N, window_size, n_ch).
+ - segmented label tokens (N, window_size, target_dim).
  """
  segments, labels = [], []
  n_samples = len(label)
  return np.array(segments), np.array(labels)

+ def process_mat_file(
+ mat_path: str,
+ window_size: int,
+ stride: int,
+ fs: float
+ ) -> Optional[Tuple[str, np.ndarray, np.ndarray]]:
+ """Processes a single NinaPro DB8 .mat file.
+
+ Loads the file, removes NaNs, normalizes EMG (Z-score), maps finger degrees
+ of freedom (DoF) to degrees of activation (DoA), and segments the data.
+
+ Args:
+ mat_path (str): Absolute path to the .mat file.
+ window_size (int): Temporal window size in samples.
+ stride (int): Stride between windows in samples.
+ fs (float): Sampling frequency in Hz.
+
+ Returns:
+ Optional[Tuple[str, np.ndarray, np.ndarray]]: A tuple of (split_name, segments, labels)
+ if the file is valid, else None.
  """
  mat = scipy.io.loadmat(mat_path)
  emg = mat["emg"] # (T, 16)
  label = mat["glove"] # (T, DoF)

+ # Drop timesteps with any NaNs in glove data
  valid = ~np.isnan(label).any(axis=1)
  emg = emg[valid]
  label = label[valid]

+ # Z-score per channel
  mu = emg.mean(axis=0)
  sd = emg.std(axis=0, ddof=1)
  sd[sd == 0] = 1.0
  emg = (emg - mu) / sd

+ # DoF β†’ DoA
  y_doa = (MATRIX_DOF2DOA @ label.T).T

+ # Windowing
  segs, labs = sliding_window_segment(emg, y_doa, window_size, stride)

+ # Determine split
  fname = os.path.basename(mat_path)
  if "_A1" in fname:
  split = "train"
scripts/emg2pose.py CHANGED
@@ -1,5 +1,7 @@
  import os
  from pathlib import Path

  import h5py
  import numpy as np
@@ -9,11 +11,31 @@ from joblib import Parallel, delayed
  from scipy.signal import iirnotch
  from tqdm import tqdm

- sequence_to_seconds = lambda seq_len, fs: seq_len / fs

- # ==== Filter functions (operate at original fs=2000) ====
- def notch_filter(data, notch_freq=50.0, Q=30.0, fs=2000.0):
  b, a = iirnotch(notch_freq, Q, fs)
  out = np.zeros_like(data)
  for ch in range(data.shape[1]):
@@ -21,7 +43,25 @@ def notch_filter(data, notch_freq=50.0, Q=30.0, fs=2000.0):
  return out

- def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=2000.0, order=4):
  nyq = 0.5 * fs
  low = lowcut / nyq
  high = highcut / nyq
@@ -32,9 +72,18 @@ def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=2000.0, order=4):
  return out

- # ==== Window segmentation ====
- def process_emg_features(emg, window_size=1000, stride=500):
- segs, lbls = [], []
  N = len(emg)
  for start in range(0, N, stride):
  end = start + window_size
@@ -45,10 +94,19 @@ def process_emg_features(emg, window_size=1000, stride=500):
  return np.array(segs)

- def process_one_recording(file_path, fs=2000.0, window_size=1000, stride=500):
- """
- Process a single recording file to extract EMG features and labels
- as to be used in the main pipeline with parallel processing.
  """
  with h5py.File(file_path, "r") as f:
  grp = f["emg2pose"]
@@ -71,7 +129,6 @@ def process_one_recording(file_path, fs=2000.0, window_size=1000, stride=500):
  return segs

- # ==== Main pipeline ====
  def main():
  import argparse
@@ -93,6 +150,12 @@ def main():
  default=-1,
  help="Number of parallel jobs to run. -1 means using all available cores.",
  )
  args.add_argument(
  "--seed", type=int, default=42, help="Random seed for reproducibility."
  )
@@ -109,45 +172,62 @@ def main():
  print(f"Window size: {window_size} samples ({window_seconds:.2f} seconds)")

  df = pd.read_csv(os.path.join(data_dir, "metadata.csv"))
- df = df.groupby("split").apply(
- lambda x: (
- x.sample(frac=args.subsample, random_state=args.seed)
- if args.subsample < 1.0
- else x
  )
- )
- df.reset_index(drop=True, inplace=True)

  splits = {}
  for split, df_ in df.groupby("split"):
  sessions = list(df_.filename)
- splits[split] = [
  Path(data_dir).expanduser().joinpath(f"{session}.hdf5")
  for session in sessions
  ]

- all_data = {"train": [], "val": [], "test": []}
-
  for split, files in splits.items():
- # Here we use joblib to parallelize the file processing, each file is processed independently as the task is embarrassingly parallel. We scale the processing across all available CPU cores since the number of files is around 25k (with training being 17k).
- results = Parallel(n_jobs=args.n_jobs)(
- delayed(process_one_recording)(file_path, fs, window_size, stride)
- for file_path in tqdm(files, desc=f"Processing {split} files")
- )
- # Collect results
- for segs in tqdm(results, desc=f"Collecting {split} data"):
- all_data[split].append(segs)

- # stack, augment train, transpose, save, and print stats
- X = np.concatenate(all_data[split], axis=0) # [N, window_size, ch]

- # transpose to [N, ch, window_size]
- X = X.transpose(0, 2, 1)

- # save
- with h5py.File(os.path.join(save_dir, f"{split}.h5"), "w") as hf:
- hf.create_dataset("data", data=X)

  if __name__ == "__main__":
- main()
  import os
+ import gc
  from pathlib import Path
+ from typing import Tuple, List, Optional, Union, Dict, Any

  import h5py
  import numpy as np
  from scipy.signal import iirnotch
  from tqdm import tqdm

+ def sequence_to_seconds(seq_len: int, fs: float) -> float:
+ """Converts a sequence length in samples to time in seconds.
+
+ Args:
+ seq_len (int): The number of samples in the sequence.
+ fs (float): The sampling frequency in Hz.
+
+ Returns:
+ float: The duration of the sequence in seconds.
+ """
+ return seq_len / fs

+ def notch_filter(data: np.ndarray, notch_freq: float = 50.0, Q: float = 30.0, fs: float = 2000.0) -> np.ndarray:
+ """Applies a notch filter to every channel of the input data independently.
+
+ Args:
+ data (np.ndarray): The input signal array of shape (T, D).
+ notch_freq (float, optional): The frequency to be removed in Hz. Defaults to 50.0.
+ Q (float, optional): The quality factor. Defaults to 30.0.
+ fs (float, optional): The sampling frequency in Hz. Defaults to 2000.0.
+
+ Returns:
+ np.ndarray: The filtered signal array.
+ """
  b, a = iirnotch(notch_freq, Q, fs)
  out = np.zeros_like(data)
  for ch in range(data.shape[1]):
  return out

+ def bandpass_filter_emg(
+ emg: np.ndarray,
+ lowcut: float = 20.0,
+ highcut: float = 90.0,
+ fs: float = 2000.0,
+ order: int = 4
+ ) -> np.ndarray:
+ """Applies a Butterworth bandpass filter to the EMG signal.
+
+ Args:
+ emg (np.ndarray): The input signal array of shape (T, D).
+ lowcut (float, optional): Lower bound of the passband in Hz. Defaults to 20.0.
+ highcut (float, optional): Upper bound of the passband in Hz. Defaults to 90.0.
+ fs (float, optional): The sampling frequency in Hz. Defaults to 2000.0.
+ order (int, optional): The order of the filter. Defaults to 4.
+
+ Returns:
+ np.ndarray: The filtered signal array.
+ """
  nyq = 0.5 * fs
  low = lowcut / nyq
  high = highcut / nyq
  return out

+ def process_emg_features(emg: np.ndarray, window_size: int = 1000, stride: int = 500) -> np.ndarray:
+ """Segments raw EMG signals into overlapping windows.
+
+ Args:
+ emg (np.ndarray): Raw EMG data of shape (T, n_ch).
+ window_size (int, optional): Number of samples per window. Defaults to 1000.
+ stride (int, optional): Number of samples to shift between windows. Defaults to 500.
+
+ Returns:
+ np.ndarray: Segmented data of shape (N, window_size, n_ch).
+ """
+ segs = []
  N = len(emg)
  for start in range(0, N, stride):
  end = start + window_size
  return np.array(segs)

+ def process_one_recording(file_path: str, fs: float = 2000.0, window_size: int = 1000, stride: int = 500) -> np.ndarray:
+ """Processes a single EMG2Pose recording file.
+
+ Loads HDF5 timeseries, filters EMG, normalizes (Z-score), and segments.
+
+ Args:
+ file_path (str): Absolute path to the .hdf5 recording file.
+ fs (float, optional): Sampling frequency in Hz. Defaults to 2000.0.
+ window_size (int, optional): Temporal window size in samples. Defaults to 1000.
+ stride (int, optional): Stride between windows in samples. Defaults to 500.
+
+ Returns:
+ np.ndarray: Array of processed segments (N, window_size, n_ch).
  """
  with h5py.File(file_path, "r") as f:
  grp = f["emg2pose"]
  return segs

  def main():
  import argparse
  default=-1,
  help="Number of parallel jobs to run. -1 means using all available cores.",
  )
+ args.add_argument(
+ "--group_size",
+ type=int,
+ default=1000,
+ help="Number of recordings processed per batch and written as one group in the output HDF5 file.",
+ )
  args.add_argument(
  "--seed", type=int, default=42, help="Random seed for reproducibility."
  )
  print(f"Window size: {window_size} samples ({window_seconds:.2f} seconds)")

  df = pd.read_csv(os.path.join(data_dir, "metadata.csv"))
+ if args.subsample < 1.0:
+ df = df.groupby("split", group_keys=False).sample(
+ frac=args.subsample, random_state=args.seed
  )
+ df = df.reset_index(drop=True)

  splits = {}
  for split, df_ in df.groupby("split"):
  sessions = list(df_.filename)
+ splits[split] = [
  Path(data_dir).expanduser().joinpath(f"{session}.hdf5")
  for session in sessions
  ]

  for split, files in splits.items():
+ out_file = os.path.join(save_dir, f"{split}.h5")
+
+ # Remove existing file if it exists so we don't accidentally append to old runs
+ if os.path.exists(out_file):
+ os.remove(out_file)
+
+ print(f"Processing {split} split ({len(files)} files)...")
+
+ with h5py.File(out_file, "w") as h5f:
+ group_idx = 0
+ with Parallel(n_jobs=args.n_jobs) as parallel:
+ with tqdm(total=len(files), desc=f"Processing & Saving {split}") as pbar:
+
+ # Iterate files in batches
+ for i in range(0, len(files), args.group_size):
+ batch_files = files[i : i + args.group_size]
+
+ # Process current batch
+ results = parallel(
+ delayed(process_one_recording)(file_path, fs, window_size, stride)
+ for file_path in batch_files
+ )
+
+ if results:
+ X_chunk = np.concatenate(results, axis=0) # [N, window_size, ch]
+ X_chunk = X_chunk.transpose(0, 2, 1) # [N, ch, window_size]
+ X_chunk = X_chunk.astype(np.float32)

+ # Write each processed batch as a group compatible with HDF5Loader
+ grp = h5f.create_group(f"data_group_{group_idx}")
+ grp.create_dataset("X", data=X_chunk)
+ group_idx += 1

+ # Explicitly clear memory of large numpy arrays
224
+ del results
225
+ if 'X_chunk' in locals():
226
+ del X_chunk
227
+ gc.collect()
228
 
229
+ pbar.update(len(batch_files))
 
 
230
 
231
 
232
  if __name__ == "__main__":
233
+ main()
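The sliding-window segmentation that `process_emg_features` applies (the loop body is elided in the hunk above) can be sketched in isolation. `segment_windows` below is a hypothetical standalone version for illustration; it simply drops any trailing samples that do not fill a complete window, which the script's own tail handling may treat differently:

```python
import numpy as np

def segment_windows(emg: np.ndarray, window_size: int = 1000, stride: int = 500) -> np.ndarray:
    """Slice a (T, n_ch) recording into overlapping (N, window_size, n_ch) windows."""
    segs = []
    # Stop once a full window no longer fits; partial tails are discarded.
    for start in range(0, len(emg) - window_size + 1, stride):
        segs.append(emg[start : start + window_size])
    return np.array(segs)

# 1 s of fake 2-channel data at 2 kHz -> 3 half-overlapping windows
x = np.arange(2000 * 2, dtype=np.float32).reshape(2000, 2)
windows = segment_windows(x)
print(windows.shape)  # (3, 1000, 2)
```

With the defaults (1000-sample windows, 500-sample stride) consecutive windows overlap by 50%, doubling the number of training segments per recording.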
scripts/epn.py CHANGED
@@ -2,6 +2,7 @@ import glob
 import json
 import os
 import sys
+from typing import Tuple, List, Optional, Union, Dict, Any
 
 import h5py
 import numpy as np
@@ -10,7 +11,17 @@ from joblib import Parallel, delayed
 from scipy.signal import iirnotch
 from tqdm.auto import tqdm
 
-sequence_to_seconds = lambda seq_len, fs: seq_len / fs
+def sequence_to_seconds(seq_len: int, fs: float) -> float:
+    """Converts a sequence length in samples to time in seconds.
+
+    Args:
+        seq_len (int): The number of samples in the sequence.
+        fs (float): The sampling frequency in Hz.
+
+    Returns:
+        float: The duration of the sequence in seconds.
+    """
+    return seq_len / fs
 
 # Sampling frequency and EMG channels
 tfs, n_ch = 200.0, 8
@@ -27,28 +38,77 @@ gesture_map = {
 }
 
 
-# Filtering utilities
-def bandpass_filter_emg(emg, low=20.0, high=90.0, fs=tfs, order=4):
+def bandpass_filter_emg(
+    emg: np.ndarray,
+    low: float = 20.0,
+    high: float = 90.0,
+    fs: float = tfs,
+    order: int = 4
+) -> np.ndarray:
+    """Applies a Butterworth bandpass filter to the EMG signal.
+
+    Args:
+        emg (np.ndarray): The input signal array of shape (n_ch, T).
+        low (float, optional): Lower bound of the passband in Hz. Defaults to 20.0.
+        high (float, optional): Upper bound of the passband in Hz. Defaults to 90.0.
+        fs (float, optional): The sampling frequency in Hz. Defaults to 200.0.
+        order (int, optional): The order of the filter. Defaults to 4.
+
+    Returns:
+        np.ndarray: The filtered signal array.
+    """
     nyq = 0.5 * fs
     b, a = signal.butter(order, [low / nyq, high / nyq], btype="bandpass")
     return signal.filtfilt(b, a, emg, axis=1)
 
 
-def notch_filter_emg(emg, notch=50.0, Q=30.0, fs=tfs):
+def notch_filter_emg(
+    emg: np.ndarray,
+    notch: float = 50.0,
+    Q: float = 30.0,
+    fs: float = tfs
+) -> np.ndarray:
+    """Applies a notch filter to remove power line interference.
+
+    Args:
+        emg (np.ndarray): The input signal array of shape (n_ch, T).
+        notch (float, optional): The frequency to be removed in Hz. Defaults to 50.0.
+        Q (float, optional): The quality factor. Defaults to 30.0.
+        fs (float, optional): The sampling frequency in Hz. Defaults to 200.0.
+
+    Returns:
+        np.ndarray: The filtered signal array.
+    """
     w0 = notch / (0.5 * fs)
     b, a = iirnotch(w0, Q)
     return signal.filtfilt(b, a, emg, axis=1)
 
 
-# Normalization helpers
-def zscore_per_channel(emg):
+def zscore_per_channel(emg: np.ndarray) -> np.ndarray:
+    """Normalizes the EMG signal using Z-score (per channel).
+
+    Args:
+        emg (np.ndarray): The input EMG signal of shape (n_ch, T).
+
+    Returns:
+        np.ndarray: The normalized EMG signal.
+    """
     mean = emg.mean(axis=1, keepdims=True)
     std = emg.std(axis=1, ddof=1, keepdims=True)
     std[std == 0] = 1.0
     return (emg - mean) / std
 
 
-def adjust_length(x, max_len):
+def adjust_length(x: np.ndarray, max_len: int) -> np.ndarray:
+    """Standardizes the temporal length of the signal by clipping or zero-padding.
+
+    Args:
+        x (np.ndarray): The input signal of shape (n_ch, T).
+        max_len (int): The target length in samples.
+
+    Returns:
+        np.ndarray: The standardized length signal of shape (n_ch, max_len).
+    """
     n_ch, seq_len = x.shape
     if seq_len >= max_len:
         return x[:, :max_len]
@@ -56,8 +116,18 @@ def adjust_length(x, max_len):
     return np.concatenate([x, pad], axis=1)
 
 
-# Single-sample processing
-def extract_emg_signal(sample, seq_len):
+def extract_emg_signal(sample: Dict[str, Any], seq_len: int) -> Tuple[np.ndarray, int]:
+    """Extracts, filters, and normalizes EMG data from a JSON sample.
+
+    Args:
+        sample (Dict[str, Any]): A single sample dictionary from the EPN612 JSON.
+        seq_len (int): Target temporal length.
+
+    Returns:
+        Tuple[np.ndarray, int]: A tuple containing:
+            - The preprocessed EMG signal (n_ch, seq_len).
+            - The gesture label ID.
+    """
     emg = np.stack([v for v in sample["emg"].values()], dtype=np.float32) / 128.0
     emg = bandpass_filter_emg(emg, 20.0, 90.0)
     emg = notch_filter_emg(emg, 50.0, 30.0)
@@ -67,8 +137,20 @@ def extract_emg_signal(sample, seq_len):
     return emg, label
 
 
-# Process one user JSON for train/validation
-def process_user_training(path, seq_len):
+def process_user_training(
+    path: str,
+    seq_len: int
+) -> Tuple[List[np.ndarray], List[int], List[np.ndarray], List[int]]:
+    """Processes a user's training JSON file for the training and validation splits.
+
+    Args:
+        path (str): Path to the user JSON file.
+        seq_len (int): Target temporal length for segmentation.
+
+    Returns:
+        Tuple[List[np.ndarray], List[int], List[np.ndarray], List[int]]:
+            (train_X, train_y, val_X, val_y) lists.
+    """
     train_X, train_y, val_X, val_y = [], [], [], []
     with open(path, "r", encoding="utf-8") as f:
         data = json.load(f)
@@ -79,14 +161,29 @@ def process_user_training(path, seq_len):
             train_y.append(lbl)
     for sample in data.get("testingSamples", {}).values():
         emg, lbl = extract_emg_signal(sample, seq_len)
         if lbl != 6:
            val_X.append(emg)
            val_y.append(lbl)
     return train_X, train_y, val_X, val_y
 
 
-# Process one user JSON for testing split
-def process_user_testing(path, seq_len):
+def process_user_testing(
+    path: str,
+    seq_len: int
+) -> Tuple[List[np.ndarray], List[int], List[np.ndarray], List[int]]:
+    """Processes a user's testing JSON file for the fine-tuning and test splits.
+
+    Args:
+        path (str): Path to the user JSON file.
+        seq_len (int): Target temporal length for segmentation.
+
+    Returns:
+        Tuple[List[np.ndarray], List[int], List[np.ndarray], List[int]]:
+            (tune_X, tune_y, test_X, test_y) lists.
+    """
     train_X, train_y, test_X, test_y = [], [], [], []
     with open(path, "r", encoding="utf-8") as f:
         data = json.load(f)
@@ -107,14 +204,19 @@ def process_user_testing(path, seq_len):
     return train_X, train_y, test_X, test_y
 
 
-# Save to HDF5
-def save_h5(path, data, labels):
+def save_h5(path: str, data: List[np.ndarray], labels: List[int]) -> None:
+    """Saves the processed EMG data and labels to an HDF5 file.
+
+    Args:
+        path (str): Output file path.
+        data (List[np.ndarray]): List of signal segments.
+        labels (List[int]): List of categorical labels.
+    """
     with h5py.File(path, "w") as f:
         f.create_dataset("data", data=np.asarray(data, np.float32))
         f.create_dataset("label", data=np.asarray(labels, np.int64))
 
 
-# Main parallelized pipeline
 def main():
     import argparse

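The clip-or-pad behaviour of `adjust_length` is easy to sanity-check in isolation. The sketch below reproduces its logic as shown in the diff, using the (n_ch, T) layout that epn.py works with:

```python
import numpy as np

def adjust_length(x: np.ndarray, max_len: int) -> np.ndarray:
    """Clip the time axis to max_len, or zero-pad on the right if too short."""
    n_ch, seq_len = x.shape
    if seq_len >= max_len:
        return x[:, :max_len]
    pad = np.zeros((n_ch, max_len - seq_len), dtype=x.dtype)
    return np.concatenate([x, pad], axis=1)

long = np.ones((8, 250), dtype=np.float32)   # longer than target -> clipped
short = np.ones((8, 120), dtype=np.float32)  # shorter than target -> zero-padded

clipped = adjust_length(long, 200)
padded = adjust_length(short, 200)
print(clipped.shape, padded.shape)  # (8, 200) (8, 200)
```

Zero-padding on the right keeps all segments at a fixed temporal length so they can be stacked into one HDF5 dataset, at the cost of silent samples at the end of short gestures.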
scripts/uci.py CHANGED
@@ -1,36 +1,78 @@
 import os
 import sys
 from pathlib import Path
+from typing import Tuple, List, Optional, Union, Dict, Any
 
 import h5py
 import numpy as np
 import scipy.signal as signal
 from scipy.signal import iirnotch
 
-sequence_to_seconds = lambda seq_len, fs: seq_len / fs
+def sequence_to_seconds(seq_len: int, fs: float) -> float:
+    """Converts a sequence length in samples to time in seconds.
+
+    Args:
+        seq_len (int): The number of samples in the sequence.
+        fs (float): The sampling frequency in Hz.
+
+    Returns:
+        float: The duration of the sequence in seconds.
+    """
+    return seq_len / fs
 
-# ─────────────────────────────────────────────
-# Filtering utilities
-# ─────────────────────────────────────────────
-def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=200.0, order=4):
+def bandpass_filter_emg(
+    emg: np.ndarray,
+    lowcut: float = 20.0,
+    highcut: float = 90.0,
+    fs: float = 200.0,
+    order: int = 4
+) -> np.ndarray:
+    """Applies a Butterworth bandpass filter to the EMG signal.
+
+    Args:
+        emg (np.ndarray): The input signal array of shape (T, D).
+        lowcut (float, optional): Lower bound of the passband in Hz. Defaults to 20.0.
+        highcut (float, optional): Upper bound of the passband in Hz. Defaults to 90.0.
+        fs (float, optional): The sampling frequency in Hz. Defaults to 200.0.
+        order (int, optional): The order of the filter. Defaults to 4.
+
+    Returns:
+        np.ndarray: The filtered signal array.
+    """
     nyq = 0.5 * fs
     b, a = signal.butter(order, [lowcut / nyq, highcut / nyq], btype="bandpass")
     return signal.filtfilt(b, a, emg, axis=0)
 
-
-def notch_filter_emg(emg, notch_freq=50.0, Q=30.0, fs=200.0):
+def notch_filter_emg(
+    emg: np.ndarray,
+    notch_freq: float = 50.0,
+    Q: float = 30.0,
+    fs: float = 200.0
+) -> np.ndarray:
+    """Applies a notch filter to remove power line interference.
+
+    Args:
+        emg (np.ndarray): The input signal array of shape (T, D).
+        notch_freq (float, optional): The frequency to be removed in Hz. Defaults to 50.0.
+        Q (float, optional): The quality factor. Defaults to 30.0.
+        fs (float, optional): The sampling frequency in Hz. Defaults to 200.0.
+
+    Returns:
+        np.ndarray: The filtered signal array.
+    """
     b, a = iirnotch(notch_freq / (0.5 * fs), Q)
     return signal.filtfilt(b, a, emg, axis=0)
 
-# ─────────────────────────────────────────────
-# Core I/O + preprocessing helpers
-# ─────────────────────────────────────────────
-def read_emg_txt(txt_path):
-    """
-    Read a txt file with columns: time ch1 … ch8 class.
-    Return float32 array of shape (N, 10).
+def read_emg_txt(txt_path: str) -> np.ndarray:
+    """Reads a UCI EMG text file into a numpy array.
+
+    The file is expected to have columns: [time, ch1, ..., ch8, class].
+
+    Args:
+        txt_path (str): Path to the .txt file.
+
+    Returns:
+        np.ndarray: A float32 array of shape (N, 10).
     """
     data = []
     with open(txt_path, "r") as f:
@@ -41,10 +83,22 @@ def read_emg_txt(txt_path):
     return np.asarray(data, dtype=np.float32)
 
 
-def preprocess_emg(arr, fs=200.0, remove_class0=True):
-    """
-    1) optional removal of class-0 rows
-    2) band-pass → notch → Z-score (on 8 channels)
+def preprocess_emg(arr: np.ndarray, fs: float = 200.0, remove_class0: bool = True) -> np.ndarray:
+    """Applies a standard preprocessing pipeline to the EMG data.
+
+    Pipeline includes:
+        1. Optional removal of rest (class 0).
+        2. Bandpass filtering (20-90 Hz).
+        3. Notch filtering (50 Hz).
+        4. Z-score normalization per channel.
+
+    Args:
+        arr (np.ndarray): Raw data array of shape (N, 10).
+        fs (float, optional): Sampling frequency in Hz. Defaults to 200.0.
+        remove_class0 (bool, optional): Whether to remove the "rest" class. Defaults to True.
+
+    Returns:
+        np.ndarray: The preprocessed data array.
     """
     if remove_class0:
         arr = arr[arr[:, -1] >= 1]
@@ -64,8 +118,15 @@ def preprocess_emg(arr, fs=200.0, remove_class0=True):
     return arr
 
 
-def find_label_runs(arr):
-    """Group consecutive rows with identical class labels."""
+def find_label_runs(arr: np.ndarray) -> List[Tuple[int, np.ndarray]]:
+    """Groups consecutive rows with identical class labels.
+
+    Args:
+        arr (np.ndarray): Data array where the last column is the class label.
+
+    Returns:
+        List[Tuple[int, np.ndarray]]: A list of tuples (label, sub-array).
+    """
     runs = []
     if arr.size == 0:
         return runs
@@ -80,7 +141,23 @@ def find_label_runs(arr):
     return runs
 
 
-def sliding_window_majority(seg_arr, window_size=1000, stride=500):
+def sliding_window_majority(
+    seg_arr: np.ndarray,
+    window_size: int = 1000,
+    stride: int = 500
+) -> Tuple[np.ndarray, np.ndarray]:
+    """Segments a label-consistent array using a sliding window and majority voting.
+
+    Args:
+        seg_arr (np.ndarray): Data array of shape (T, 10).
+        window_size (int, optional): Number of samples per window. Defaults to 1000.
+        stride (int, optional): Number of samples to shift between windows. Defaults to 500.
+
+    Returns:
+        Tuple[np.ndarray, np.ndarray]: A tuple containing:
+            - Windowed EMG segments (N, window_size, 8).
+            - Majority vote labels (N,).
+    """
     segs, labs = [], []
     for start in range(0, len(seg_arr) - window_size + 1, stride):
         win = seg_arr[start : start + window_size]
@@ -91,8 +168,24 @@ def sliding_window_majority(seg_arr, window_size=1000, stride=500):
 
 
 def users_with_gesture(
-    data_root, gesture_id, subj_range=range(1, 37), return_counts=False
-):
+    data_root: str,
+    gesture_id: int,
+    subj_range: range = range(1, 37),
+    return_counts: bool = False
+) -> Union[List[int], Dict[int, int]]:
+    """Identifies which subjects performed a specific gesture.
+
+    Args:
+        data_root (str): Root directory of the dataset.
+        gesture_id (int): The ID of the gesture to search for.
+        subj_range (range, optional): Range of subject IDs to check. Defaults to range(1, 37).
+        return_counts (bool, optional): If True, returns a dictionary with sample counts.
+            Defaults to False.
+
+    Returns:
+        Union[List[int], Dict[int, int]]: Either a list of subject IDs or a dictionary
+            mapping subject ID to occurrence count.
+    """
     found = {}
     for subj in subj_range:
         subj_dir = os.path.join(data_root, f"{subj:02d}")
@@ -122,21 +215,28 @@ def users_with_gesture(
     else:
         return sorted(found.keys())
 
-
-# ─────────────────────────────────────────────
-# Safe concatenation utilities
-# ─────────────────────────────────────────────
-def concat_data(lst):  # lst of (N,256,8)
+def concat_data(lst: List[np.ndarray]) -> np.ndarray:
+    """Concatenates a list of data arrays.
+
+    Args:
+        lst (List[np.ndarray]): List of arrays to concatenate.
+
+    Returns:
+        np.ndarray: Concatenated array or empty array if list is empty.
+    """
     return np.concatenate(lst, axis=0) if lst else np.empty((0, 1000, 8), np.float32)
 
-def concat_label(lst):
-    return np.concatenate(lst, axis=0) if lst else np.empty((0,), np.int32)
+def concat_label(lst: List[np.ndarray]) -> np.ndarray:
+    """Concatenates a list of label arrays.
+
+    Args:
+        lst (List[np.ndarray]): List of label arrays.
+
+    Returns:
+        np.ndarray: Concatenated array or empty array if list is empty.
+    """
+    return np.concatenate(lst, axis=0) if lst else np.empty((0,), np.int32)
 
-# ─────────────────────────────────────────────
-# Main
-# ─────────────────────────────────────────────
 if __name__ == "__main__":
     import argparse
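To illustrate the majority-vote windowing in uci.py's `sliding_window_majority` (whose loop body is elided in the hunk above), here is a minimal sketch on a toy (T, 4) array laid out as [time, ch1, ch2, class]. The real script operates on (T, 10) arrays with 8 EMG channels, and the voting details inside the elided body are an assumption here:

```python
import numpy as np

def sliding_window_majority(seg_arr, window_size=4, stride=2):
    """Window the rows and label each window with its most frequent class."""
    segs, labs = [], []
    for start in range(0, len(seg_arr) - window_size + 1, stride):
        win = seg_arr[start : start + window_size]
        labels = win[:, -1].astype(np.int64)
        segs.append(win[:, 1:-1])                  # keep only the EMG channels
        labs.append(np.bincount(labels).argmax())  # majority class; ties go to the lower id
    return np.array(segs), np.array(labs)

t = np.arange(6).reshape(6, 1)
ch = np.ones((6, 2))
cls = np.array([1, 1, 2, 2, 2, 2]).reshape(6, 1)
arr = np.hstack([t, ch, cls]).astype(np.float32)

X, y = sliding_window_majority(arr)
print(X.shape, y)  # (2, 4, 2) [1 2]
```

Majority voting matters at run boundaries: a window that straddles two label runs still gets a single, dominant label rather than being discarded.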