Commit · 45d17fb
Parent(s): 6d283f0
refactor: EMG processing scripts and documentation

- README.md +47 -280
- scripts/README.md +26 -125
- scripts/db5.py +122 -13
- scripts/db6.py +108 -60
- scripts/db7.py +108 -58
- scripts/db8.py +87 -20
- scripts/emg2pose.py +118 -38
- scripts/epn.py +118 -16
- scripts/uci.py +132 -32
README.md
CHANGED
|
@@ -103,310 +103,77 @@ tags:
|
|
| 103 |
|
| 104 |
<div align="center">
|
| 105 |
<img src="https://raw.githubusercontent.com/MatteoFasulo/BioFoundation/refs/heads/TinyMyo/docs/model/logo/TinyMyo_logo.png" alt="TinyMyo Logo" width="400" />
|
| 106 |
-
<h1>TinyMyo: a Tiny Foundation Model for Flexible EMG Signal Processing at the Edge</h1>
|
| 107 |
-
</div>
|
| 108 |
-
<p align="center">
|
| 109 |
-
<a href="https://github.com/pulp-bio/BioFoundation">
|
| 110 |
-
|
| 111 |
-
</a>
|
| 112 |
-
<a href="https://creativecommons.org/licenses/by-nd/4.0/">
|
| 113 |
-
<img src="https://img.shields.io/badge/License-CC_BY--ND_4.0-lightgrey.svg" alt="License">
|
| 114 |
-
</a>
|
| 115 |
-
<a href="https://arxiv.org/abs/2512.15729">
|
| 116 |
-
<img src="https://img.shields.io/badge/arXiv-2512.15729-b31b1b.svg" alt="Paper">
|
| 117 |
-
</a>
|
| 118 |
</p>
|
| 119 |
|
| 120 |
-
**TinyMyo** is a **3.6M-parameter** Transformer foundation model for surface EMG (sEMG).
|
| 121 |
-
It is pretrained on >480 GB of EMG data and optimized for **ultra-low-power, real-time deployment**, including **microcontrollers (GAP9)** where it achieves an inference time of **0.785 s**, energy of **44.91 mJ** and power envelope of **57.18 mW**.
|
| 122 |
-
|
| 123 |
-
TinyMyo is built for **broad generalization** across datasets, sensor configurations, movement tasks, subjects, and domains (gesture, kinematics, speech).
|
| 124 |
-
|
| 125 |
-
---
|
| 126 |
-
|
| 127 |
-
# 📜 License & Usage (Model Weights)
|
| 128 |
-
|
| 129 |
-
The released TinyMyo weights are licensed under **CC BY-ND 4.0**.
|
| 130 |
-
This summary is not legal advice; please read the full license.
|
| 131 |
-
|
| 132 |
-
### ✅ You may
|
| 133 |
-
|
| 134 |
-
* **Use** and **redistribute** the **unmodified** TinyMyo weights (including commercially) **with attribution**.
|
| 135 |
-
* **Fine-tune/modify internally** for research or production without redistributing modified weights.
|
| 136 |
-
* **Publish code, configs, evaluations, and papers** using TinyMyo.
|
| 137 |
-
|
| 138 |
-
### 🚫 You may not
|
| 139 |
-
|
| 140 |
-
* **Share or host modified weights** in any form (including LoRA/adapter deltas, pruned/quantized models).
|
| 141 |
-
* **Claim endorsement** from the TinyMyo authors without permission.
|
| 142 |
-
* **Use the TinyMyo name** for derivative models.
|
| 143 |
-
|
| 144 |
-
### 🤝 Contributing Improvements
|
| 145 |
-
|
| 146 |
-
To upstream improvements, submit a **PR** to the
|
| 147 |
-
**[BioFoundation repository](https://github.com/pulp-bio/BioFoundation)** with:
|
| 148 |
-
|
| 149 |
-
1. Full reproducibility artifacts (configs, logs, seeds, environment).
|
| 150 |
-
2. Evaluation on standard protocols (e.g., DB5, EPN-612, UCI EMG, DB8, Silent Speech).
|
| 151 |
-
3. Comparison to TinyMyo's reported metrics.
|
| 152 |
-
|
| 153 |
-
Approved PRs will be retrained and released as **official TinyMyo** checkpoints under CC BY-ND.
|
| 154 |
-
|
| 155 |
-
---
|
| 156 |
-
|
| 157 |
-
# 📏 1. Default Input & Preprocessing
|
| 158 |
-
|
| 159 |
-
Unless specified otherwise, TinyMyo expects:
|
| 160 |
-
|
| 161 |
-
* **Channels:** 16
|
| 162 |
-
* **Sampling rate:** 2000 Hz
|
| 163 |
-
* **Segment length:** 1000 samples (0.5 s)
|
| 164 |
-
* **Windowing:** 50% overlap (pretraining)
|
| 165 |
-
* **Preprocessing:**
|
| 166 |
-
|
| 167 |
-
* 4th-order **20–450 Hz bandpass**
|
| 168 |
-
* **50 Hz notch filter**
|
| 169 |
-
* **Min–max normalization** (pretraining)
|
| 170 |
-
* **Z-score normalization** (downstream)
|
| 171 |
-
|
| 172 |
-
Datasets with <16 channels are **zero-padded (pretraining only)**.
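The defaults above can be sketched with SciPy (an illustrative sketch only: `preprocess` is a name introduced here, not a function from this repository; min–max normalization is shown as used in pretraining):

```python
import numpy as np
from scipy.signal import butter, filtfilt, iirnotch

def preprocess(emg: np.ndarray, fs: float = 2000.0) -> np.ndarray:
    """emg: (samples, channels) at 2 kHz."""
    # 4th-order 20-450 Hz Butterworth bandpass
    b, a = butter(4, [20.0, 450.0], btype="bandpass", fs=fs)
    out = filtfilt(b, a, emg, axis=0)
    # 50 Hz notch against power-line interference
    bn, an = iirnotch(50.0, Q=30.0, fs=fs)
    out = filtfilt(bn, an, out, axis=0)
    # per-channel min-max normalization (pretraining)
    mn, mx = out.min(axis=0), out.max(axis=0)
    return (out - mn) / (mx - mn + 1e-8)

y = preprocess(np.random.randn(1000, 16))  # one 0.5 s, 16-channel segment
```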
|
| 173 |
-
|
| 174 |
-
---
|
| 175 |
-
|
| 176 |
-
# 🔬 2. Pretraining Overview
|
| 177 |
-
|
| 178 |
-
TinyMyo is pretrained via masked reconstruction on **three large-scale EMG datasets**:
|
| 179 |
-
|
| 180 |
-
| Dataset | Subjects | fs | Channels | Size |
|
| 181 |
-
| ----------- | -------- | ------- | -------- | ------- |
|
| 182 |
-
| Ninapro DB6 | 10 | 2000 Hz | 14 | 20.3 GB |
|
| 183 |
-
| Ninapro DB7 | 22 | 2000 Hz | 12 | 30.9 GB |
|
| 184 |
-
| EMG2Pose | 192 | 2000 Hz | 16 | 431 GB |
|
| 185 |
-
|
| 186 |
-
## Tokenization: Channel-Independent Patches
|
| 187 |
-
|
| 188 |
-
Unlike EEG FMs that mix channels early, TinyMyo uses **per-channel patching**:
|
| 189 |
-
|
| 190 |
-
* Patch length: **20 samples**
|
| 191 |
-
* Patch stride: **20 samples**
|
| 192 |
-
* Tokens/channel: **50**
|
| 193 |
-
* Total seq length: **800 tokens** (16 × 50)
|
| 194 |
-
* Positional encoding: **RoPE**
|
| 195 |
-
|
| 196 |
-
This preserves electrode-specific structure while allowing attention to learn cross-channel relationships.
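Concretely, per-channel patching with equal length and stride is just a reshape; a minimal NumPy sketch (`patchify` is an illustrative name, not the repository's API):

```python
import numpy as np

def patchify(emg: np.ndarray, patch_len: int = 20) -> np.ndarray:
    """emg: (channels, samples) -> (tokens, patch_len), channel-major order."""
    C, T = emg.shape
    # split each channel's time axis into non-overlapping 20-sample patches
    return emg.reshape(C, T // patch_len, patch_len).reshape(-1, patch_len)

emg = np.random.randn(16, 1000)   # default input: 16 channels x 0.5 s @ 2 kHz
tokens = patchify(emg)
print(tokens.shape)               # (800, 20): 16 channels x 50 tokens each
```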
|
| 197 |
-
|
| 198 |
-
## Transformer Encoder
|
| 199 |
-
|
| 200 |
-
* **8 layers**, **3 heads**
|
| 201 |
-
* Embedding dim: **192**
|
| 202 |
-
* Pre-LayerNorm
|
| 203 |
-
* Dropout & drop-path: **0.1**
|
| 204 |
-
|
| 205 |
-
## Lightweight Decoder
|
| 206 |
-
|
| 207 |
-
A **single linear layer** (~3.9k params) reconstructs masked patches.
|
| 208 |
-
Following SimMIM, this forces the encoder to learn robust latent structure.
|
| 209 |
-
|
| 210 |
-
## Masking Objective
|
| 211 |
-
|
| 212 |
-
* **50% random masking** with a learnable `[MASK]` token
|
| 213 |
-
* Loss: **Smooth L1** with small penalty on visible patches
|
| 214 |
-
$$
|
| 215 |
-
\mathcal{L} = \mathcal{L}_{\text{masked}} + 0.1\,\mathcal{L}_{\text{visible}}
|
| 216 |
-
$$
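A minimal NumPy sketch of this objective (names are illustrative and β=1 is assumed for the Smooth L1 term; the actual training code lives in the BioFoundation framework):

```python
import numpy as np

def smooth_l1(pred, target, beta=1.0):
    # standard Smooth L1: quadratic below beta, linear above
    d = np.abs(pred - target)
    return np.where(d < beta, 0.5 * d**2 / beta, d - 0.5 * beta).mean()

def pretrain_loss(pred, target, mask, visible_weight=0.1):
    # mask is True for masked patches (the reconstruction targets)
    return smooth_l1(pred[mask], target[mask]) + \
        visible_weight * smooth_l1(pred[~mask], target[~mask])

pred = np.ones((10, 20))      # 10 patches of 20 samples, all off by 1
target = np.zeros((10, 20))
mask = np.array([True] * 5 + [False] * 5)
loss = pretrain_loss(pred, target, mask)  # 0.5 + 0.1 * 0.5 = 0.55
```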
|
| 217 |
-
|
| 218 |
-
## Training Setup
|
| 219 |
-
|
| 220 |
-
* Optimizer: **AdamW** (β = (0.9, 0.98), wd = 0.01)
|
| 221 |
-
* LR: **1e-4** with cosine decay
|
| 222 |
-
* Batch size: **512** (with grad accumulation)
|
| 223 |
-
* Epochs: **50**, warm-up: 10
|
| 224 |
-
* Hardware: **4× NVIDIA GH200 GPUs**
|
| 225 |
-
|
| 226 |
-
---
|
| 227 |
-
|
| 228 |
-
# 🧠 3. Architecture Summary
|
| 229 |
-
|
| 230 |
-
### Model Variant
|
| 231 |
-
|
| 232 |
-
| Variant | Params | (Layers, Dim) |
|
| 233 |
-
| ------- | -------- | ------------- |
|
| 234 |
-
| TinyMyo | **3.6M** | (8, 192) |
|
| 235 |
-
|
| 236 |
-
---
|
| 237 |
-
|
| 238 |
-
# 🎯 4. Downstream Tasks
|
| 239 |
-
|
| 240 |
-
TinyMyo generalizes across **gesture classification**, **kinematic regression**, and **speech EMG**, with state-of-the-art or competitive results.
|
| 241 |
-
|
| 242 |
-
---
|
| 243 |
-
|
| 244 |
-
## 4.1 Hand Gesture Classification
|
| 245 |
-
|
| 246 |
-
Evaluated on:
|
| 247 |
-
|
| 248 |
-
* **Ninapro DB5** (52 classes, 10 subjects, 200 Hz)
|
| 249 |
-
* **EPN-612** (5 classes, 612 subjects, 200 Hz)
|
| 250 |
-
* **UCI EMG** (6 classes, 36 subjects, 200 Hz)
|
| 251 |
-
|
| 252 |
-
### Preprocessing
|
| 253 |
-
|
| 254 |
-
* EMG filtering: **20–90 Hz bandpass + 50 Hz notch**
|
| 255 |
-
* Window sizes:
|
| 256 |
-
|
| 257 |
-
* **200 samples** (1 sec, best for DB5)
|
| 258 |
-
* **1000 samples** (5 sec, best for EPN, UCI)
|
| 259 |
-
|
| 260 |
-
### Linear Classification Head
|
| 261 |
-
|
| 262 |
-
* Input: **C Γ 192**
|
| 263 |
-
* Params: **<40k**
|
| 264 |
-
|
| 265 |
-
### Performance (Fine-tuned)
|
| 266 |
-
|
| 267 |
-
| Dataset | Metric | Result |
|
| 268 |
-
| ------------------------ | ------ | ----------------- |
|
| 269 |
-
| **Ninapro DB5** (1 sec)  | Acc    | **89.41 ± 0.16%** |
|
| 270 |
-
| **EPN-612** (5 sec)      | Acc    | **96.74 ± 0.09%** |
|
| 271 |
-
| **UCI EMG** (5 sec)      | Acc    | **97.56 ± 0.32%** |
|
| 272 |
-
|
| 273 |
-
TinyMyo achieves **new state-of-the-art** on DB5, EPN-612, and UCI.
|
| 274 |
-
|
| 275 |
-
---
|
| 276 |
-
|
| 277 |
-
## 4.2 Hand Kinematic Regression (Ninapro DB8)
|
| 278 |
-
|
| 279 |
-
* Predict **5 joint angles**
|
| 280 |
-
* Windows: **100 ms** or **500 ms**
|
| 281 |
-
* Normalization: z-score only
|
| 282 |
-
|
| 283 |
-
### Regression Head (~788k params)
|
| 284 |
-
|
| 285 |
-
* Depthwise + pointwise convs
|
| 286 |
-
* Upsampling
|
| 287 |
-
* Global average pooling
|
| 288 |
-
* Linear projection to 5 outputs
|
| 289 |
-
|
| 290 |
-
### Performance
|
| 291 |
-
|
| 292 |
-
* **MAE = 8.77 ± 0.12°** (500 ms)
|
| 293 |
-
|
| 294 |
-
Note: Prior works reporting ~6.9° MAE are **subject-specific**; TinyMyo trains a **single cross-subject model**, a significantly harder setting.
|
| 295 |
-
|
| 296 |
-
---
|
| 297 |
-
|
| 298 |
-
## 4.3 Speech Production & Recognition (Silent Speech)
|
| 299 |
-
|
| 300 |
-
Dataset: **Gaddy Silent Speech**
|
| 301 |
-
(8 channels, 1000 Hz, face/neck EMG)
|
| 302 |
-
|
| 303 |
-
### Speech Production (EMG → MFCC → HiFi-GAN → Audio)
|
| 304 |
-
|
| 305 |
-
Pipeline:
|
| 306 |
-
|
| 307 |
-
1. Residual downsampling
|
| 308 |
-
2. TinyMyo encoder
|
| 309 |
-
3. Linear projection → **26-dim MFCC**
|
| 310 |
-
4. HiFi-GAN vocoder
|
| 311 |
-
|
| 312 |
-
**WER:** **33.54 ± 1.12%**
|
| 313 |
-
→ state-of-the-art with **>90% fewer params** in the transduction model.
|
| 314 |
-
|
| 315 |
-
### Speech Recognition (EMG β Text)
|
| 316 |
-
|
| 317 |
-
* TinyMyo encoder
|
| 318 |
-
* Linear projection → **37 characters**
|
| 319 |
-
* **CTC** loss
|
| 320 |
-
* 4-gram LM + beam search
|
| 321 |
-
|
| 322 |
-
**WER:** **33.95 ± 0.97%**
|
| 323 |
-
|
| 324 |
-
TinyMyo is EMG-only, unlike multimodal systems like MONA-LISA.
|
| 325 |
|
| 326 |
---
|
| 327 |
|
| 328 |
-
# ⚡ 5. Edge Deployment (GAP9)
|
| 329 |
-
|
| 330 |
-
TinyMyo runs efficiently on **GAP9 (RISC-V)** via:
|
| 331 |
-
|
| 332 |
-
* **INT8 quantization**, including attention
|
| 333 |
-
* Multi-level streaming (L3 to L2 to L1)
|
| 334 |
-
* Integer LayerNorm, GELU, softmax
|
| 335 |
-
* Static memory arena via liveness analysis
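As a generic illustration of symmetric per-tensor INT8 quantization (a sketch of the idea only; the GAP9 deployment uses its own quantization toolchain):

```python
import numpy as np

def quantize_int8(x: np.ndarray):
    # real value ~= scale * int8 code, symmetric around zero
    scale = np.abs(x).max() / 127.0
    q = np.clip(np.round(x / scale), -128, 127).astype(np.int8)
    return q, scale

x = np.random.randn(192).astype(np.float32)   # e.g. one 192-dim embedding
q, scale = quantize_int8(x)
err = np.abs(q.astype(np.float32) * scale - x).max()  # bounded by scale / 2
```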
|
| 336 |
|
| 337 |
-
|
| 338 |
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
This is the **first EMG foundation model demonstrated on a microcontroller**.
|
| 344 |
-
|
| 345 |
-
---
|
| 346 |
|
| 347 |
-
#
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
*
|
| 353 |
-
|
| 354 |
-
### Downstream Highlights
|
| 355 |
-
|
| 356 |
-
* **DB5:** 89.41%
|
| 357 |
-
* **EPN-612:** 96.74%
|
| 358 |
-
* **UCI EMG:** 97.56%
|
| 359 |
-
* **Neuromotor:** 0.153 CLER
|
| 360 |
-
* **DB8 Regression:** MAE 8.77°
|
| 361 |
-
* **Silent Speech Production:** 33.54% WER
|
| 362 |
-
* **Silent Speech Recognition:** 33.95% WER
|
| 363 |
|
| 364 |
-
|
| 365 |
|
| 366 |
---
|
| 367 |
|
| 368 |
-
#
|
| 369 |
|
| 370 |
-
|
| 371 |
-
**[BioFoundation repository](https://github.com/pulp-bio/BioFoundation)**.
|
| 372 |
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
*
|
| 381 |
-
* `CHECKPOINT_DIR`: checkpoint to load
|
| 382 |
|
| 383 |
---
|
| 384 |
|
| 385 |
-
##
|
| 386 |
-
|
| 387 |
-
- **
|
| 388 |
|
| 389 |
---
|
| 390 |
|
| 391 |
-
# 📚 Citation
|
| 392 |
-
|
| 393 |
-
Please cite TinyMyo using:
|
| 394 |
|
| 395 |
```bibtex
|
| 396 |
-
@misc{
|
| 397 |
-
title={TinyMyo: a Tiny Foundation Model for Flexible EMG Signal Processing at the Edge},
|
| 398 |
author={Matteo Fasulo and Giusy Spacone and Thorir Mar Ingolfsson and Yawei Li and Luca Benini and Andrea Cossettini},
|
| 399 |
-
year={
|
| 400 |
eprint={2512.15729},
|
| 401 |
archivePrefix={arXiv},
|
| 402 |
primaryClass={eess.SP},
|
| 403 |
-
url={https://arxiv.org/abs/2512.15729},
|
| 404 |
}
|
| 405 |
-
```
|
| 406 |
-
|
| 407 |
-
---
|
| 408 |
-
|
| 409 |
-
# 📧 Contact & Support
|
| 410 |
-
|
| 411 |
-
* Questions or issues?
|
| 412 |
-
Open an issue on the **BioFoundation GitHub repository**.
|
| 103 |
|
| 104 |
<div align="center">
|
| 105 |
<img src="https://raw.githubusercontent.com/MatteoFasulo/BioFoundation/refs/heads/TinyMyo/docs/model/logo/TinyMyo_logo.png" alt="TinyMyo Logo" width="400" />
|
| 106 |
+
<h1>TinyMyo: a Tiny Foundation Model for Flexible EMG Signal Processing at the Edge</h1>
|
| 107 |
+
</div>
|
| 108 |
+
<p align="center">
|
| 109 |
+
<a href="https://github.com/pulp-bio/BioFoundation"><img src="https://img.shields.io/github/stars/pulp-bio/BioFoundation?color=ccf" alt="Github"></a>
|
| 110 |
+
<a href="https://creativecommons.org/licenses/by-nd/4.0/"><img src="https://img.shields.io/badge/License-CC_BY--ND_4.0-lightgrey.svg" alt="License"></a>
|
| 111 |
+
<a href="https://arxiv.org/abs/2512.15729"><img src="https://img.shields.io/badge/arXiv-2512.15729-b31b1b.svg" alt="Paper"></a>
|
| 112 |
</p>
|
| 113 |
|
| 114 |
+
**TinyMyo** is a **3.6M-parameter** Transformer foundation model for surface EMG (sEMG), optimized for ultra-low-power edge deployment (GAP9 MCU). It demonstrates state-of-the-art performance across gesture classification, kinematic regression, and speech synthesis.
|
| 115 |
|
| 116 |
---
|
| 117 |
|
| 118 |
+
## 🚀 Quick Start
|
| 119 |
|
| 120 |
+
TinyMyo is built as a specialized model within the [BioFoundation](https://github.com/pulp-bio/BioFoundation) framework.
|
| 121 |
|
| 122 |
+
### 1. Requirements
|
| 123 |
+
- **Preprocessing:** Dependencies for data scripts are in `scripts/requirements.txt`.
|
| 124 |
+
- **BioFoundation:** Full framework requirements for training/inference are in the [GitHub repository](https://github.com/pulp-bio/BioFoundation/blob/main/requirements.txt).
|
| 125 |
|
| 126 |
+
### 2. Preprocessing
|
| 127 |
+
Process raw datasets into HDF5 format:
|
| 128 |
+
```bash
|
| 129 |
+
python scripts/db5.py --data_dir $DATA_PATH/raw/ --save_dir $DATA_PATH/h5/ --seq_len 200 --stride 50
|
| 130 |
+
```
|
| 131 |
+
*See [scripts/README.md](scripts/README.md) for all dataset commands.*
|
|
| 132 |
|
| 133 |
+
### 3. Fine-tuning
|
| 134 |
+
```bash
|
| 135 |
+
python run_train.py +experiment=TinyMyo_finetune pretrained_safetensors_path=/path/to/base.safetensors
|
| 136 |
+
```
|
| 137 |
|
| 138 |
---
|
| 139 |
|
| 140 |
+
## 🧠 Architecture & Pretraining
|
| 141 |
+
- **Core:** 8-layer Transformer encoder (192-dim embeddings, 3 heads).
|
| 142 |
+
- **Tokenization:** Channel-independent patching (20 samples/patch) with RoPE.
|
| 143 |
+
- **Data:** Pretrained on >480 GB of EMG (NinaPro DB6/7, EMG2Pose).
|
| 144 |
+
- **Specs:** 3.6M parameters, 4.0 GFLOPs.
|
| 145 |
|
| 146 |
+
## 🎯 Benchmarks
|
| 147 |
|
| 148 |
+
| Task | Dataset | Metric | TinyMyo |
|
| 149 |
+
| :--- | :--- | :--- | :--- |
|
| 150 |
+
| **Gesture** | NinaPro DB5 | Accuracy | **89.41%** |
|
| 151 |
+
| **Gesture** | EPN-612 | Accuracy | **96.74%** |
|
| 152 |
+
| **Gesture** | UCI EMG | Accuracy | **97.56%** |
|
| 153 |
+
| **Regression**| NinaPro DB8 | MAE | **8.77°** |
|
| 154 |
+
| **Speech** | Gaddy (Speech Synthesis) | WER | **33.54%** |
|
| 155 |
+
| **Speech** | Gaddy (Speech Recognition) | WER | **33.95%** |
|
| 156 |
|
| 157 |
---
|
| 158 |
|
| 159 |
+
## ⚡ Edge Performance (GAP9 MCU)
|
| 160 |
+
- **Inference:** 0.785 s
|
| 161 |
+
- **Energy:** 44.91 mJ
|
| 162 |
+
- **Power:** 57.18 mW
|
| 163 |
|
| 164 |
---
|
| 165 |
|
| 166 |
+
## 📜 License & Citation
|
| 167 |
+
Weights are licensed under **CC BY-ND 4.0**. See [LICENSE](LICENSE) for details.
|
| 168 |
|
| 169 |
```bibtex
|
| 170 |
+
@misc{fasulo2026tinymyotinyfoundationmodel,
|
| 171 |
+
title={TinyMyo: a Tiny Foundation Model for Flexible EMG Signal Processing at the Edge},
|
| 172 |
author={Matteo Fasulo and Giusy Spacone and Thorir Mar Ingolfsson and Yawei Li and Luca Benini and Andrea Cossettini},
|
| 173 |
+
year={2026},
|
| 174 |
eprint={2512.15729},
|
| 175 |
archivePrefix={arXiv},
|
| 176 |
primaryClass={eess.SP},
|
| 177 |
+
url={https://arxiv.org/abs/2512.15729},
|
| 178 |
}
|
| 179 |
+
```
|
scripts/README.md
CHANGED
|
@@ -1,137 +1,38 @@
|
|
| 1 |
-
# Dataset Preparation
|
| 2 |
|
| 3 |
-
|
| 4 |
|
| 5 |
-
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
Substitute the `$DATA_PATH` environment variable with your path for saving the dataset.
|
| 10 |
-
|
| 11 |
-
The `seq_len` parameter in the scripts corresponds to the window size in samples, and the `stride` parameter corresponds to the step size between windows in samples. The sampling rate for the pretraining datasets is 2 kHz, while for the downstream datasets it is either 200 Hz or 2 kHz depending on the dataset.
|
| 12 |
-
|
| 13 |
-
The required libraries for running the scripts are located inside the `requirements.txt` file.
|
| 14 |
|
| 15 |
## Pretraining Datasets
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
```bash
|
| 24 |
-
python scripts/emg2pose.py \
|
| 25 |
-
--data_dir $DATA_PATH/datasets/emg2pose_data/ \
|
| 26 |
-
--save_dir $DATA_PATH/datasets/emg2pose_data/h5/ \
|
| 27 |
-
--seq_len 1000 \
|
| 28 |
-
--stride 500
|
| 29 |
-
```
|
| 30 |
-
|
| 31 |
-
### Ninapro DB6 (0.5 sec, 50% overlap)
|
| 32 |
-
|
| 33 |
-
```bash
|
| 34 |
-
python scripts/db6.py \
|
| 35 |
-
--data_dir $DATA_PATH/datasets/ninapro/DB6/ \
|
| 36 |
-
--save_dir $DATA_PATH/datasets/ninapro/DB6/h5/ \
|
| 37 |
-
--seq_len 1000 \
|
| 38 |
-
--stride 500
|
| 39 |
-
```
|
| 40 |
-
|
| 41 |
-
### Ninapro DB7 (0.5 sec, 50% overlap)
|
| 42 |
-
|
| 43 |
-
```bash
|
| 44 |
-
python scripts/db7.py \
|
| 45 |
-
--data_dir $DATA_PATH/datasets/ninapro/DB7/ \
|
| 46 |
-
--save_dir $DATA_PATH/datasets/ninapro/DB7/h5/ \
|
| 47 |
-
--seq_len 1000 \
|
| 48 |
-
--stride 500
|
| 49 |
-
```
|
| 50 |
|
| 51 |
---
|
| 52 |
|
| 53 |
## Downstream Datasets
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
``
|
| 60 |
-
python scripts/
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
```
|
| 66 |
-
|
| 67 |
-
### Ninapro DB5 (5 sec, 25% overlap)
|
| 68 |
-
|
| 69 |
-
```bash
|
| 70 |
-
python scripts/db5.py \
|
| 71 |
-
--data_dir $DATA_PATH/datasets/ninapro/DB5/ \
|
| 72 |
-
--save_dir $DATA_PATH/datasets/ninapro/DB5/h5/ \
|
| 73 |
-
--seq_len 1000 \
|
| 74 |
-
--stride 250
|
| 75 |
-
```
|
| 76 |
-
|
| 77 |
-
### EMG-EPN612 (1 sec, no overlap)
|
| 78 |
-
|
| 79 |
-
```bash
|
| 80 |
-
python scripts/epn.py \
|
| 81 |
-
--data_dir $DATA_PATH/datasets/EPN612/ \
|
| 82 |
-
--source_training $DATA_PATH/datasets/EPN612/trainingJSON/ \
|
| 83 |
-
--source_testing $DATA_PATH/datasets/EPN612/testingJSON/ \
|
| 84 |
-
--dest_dir $DATA_PATH/datasets/EPN612/h5/ \
|
| 85 |
-
--seq_len 200
|
| 86 |
-
```
|
| 87 |
-
|
| 88 |
-
### EMG-EPN612 (5 sec, no overlap)
|
| 89 |
-
|
| 90 |
-
```bash
|
| 91 |
-
python scripts/epn.py \
|
| 92 |
-
--data_dir $DATA_PATH/datasets/EPN612/ \
|
| 93 |
-
--source_training $DATA_PATH/datasets/EPN612/trainingJSON/ \
|
| 94 |
-
--source_testing $DATA_PATH/datasets/EPN612/testingJSON/ \
|
| 95 |
-
--dest_dir $DATA_PATH/datasets/EPN612/h5/ \
|
| 96 |
-
--seq_len 1000
|
| 97 |
-
```
|
| 98 |
-
|
| 99 |
-
### UCI EMG (1 sec, 25% overlap)
|
| 100 |
-
|
| 101 |
-
```bash
|
| 102 |
-
python scripts/uci.py \
|
| 103 |
-
--data_dir $DATA_PATH/datasets/UCI_EMG/EMG_data_for_gestures-master/ \
|
| 104 |
-
--save_dir $DATA_PATH/datasets/UCI_EMG/EMG_data_for_gestures-master/h5/ \
|
| 105 |
-
--seq_len 200 \
|
| 106 |
-
--stride 50
|
| 107 |
-
```
|
| 108 |
-
|
| 109 |
-
### UCI EMG (5 sec, 25% overlap)
|
| 110 |
-
|
| 111 |
-
```bash
|
| 112 |
-
python scripts/uci.py \
|
| 113 |
-
--data_dir $DATA_PATH/datasets/UCI_EMG/EMG_data_for_gestures-master/ \
|
| 114 |
-
--save_dir $DATA_PATH/datasets/UCI_EMG/EMG_data_for_gestures-master/h5/ \
|
| 115 |
-
--seq_len 1000 \
|
| 116 |
-
--stride 250
|
| 117 |
-
```
|
| 118 |
-
|
| 119 |
-
### Ninapro DB8 (100 ms, no overlap)
|
| 120 |
-
|
| 121 |
-
```bash
|
| 122 |
-
python scripts/db8.py \
|
| 123 |
-
--data_dir $DATA_PATH/datasets/ninapro/DB8/ \
|
| 124 |
-
--save_dir $DATA_PATH/datasets/ninapro/DB8/h5/ \
|
| 125 |
-
--seq_len 200 \
|
| 126 |
-
--stride 200
|
| 127 |
-
```
|
| 128 |
-
|
| 129 |
-
### Ninapro DB8 (500 ms, no overlap)
|
| 130 |
|
| 131 |
-
```bash
|
| 132 |
-
python scripts/db8.py \
|
| 133 |
-
--data_dir $DATA_PATH/datasets/ninapro/DB8/ \
|
| 134 |
-
--save_dir $DATA_PATH/datasets/ninapro/DB8/h5/ \
|
| 135 |
-
--seq_len 1000 \
|
| 136 |
-
--stride 1000
|
| 137 |
-
```
|
|
| 1 |
+
# Dataset Preparation
|
| 2 |
|
| 3 |
+
This guide provides commands to process raw EMG data into HDF5 format using sliding windows.
|
| 4 |
|
| 5 |
+
### Usage
|
| 6 |
+
- **Dependencies:** Install requirements specific to these scripts via `pip install -r scripts/requirements.txt`. Framework requirements for TinyMyo are in the [BioFoundation repository](https://github.com/pulp-bio/BioFoundation).
|
| 7 |
+
- Use `--download_data` if raw data is missing.
|
| 8 |
+
- Replace `$DATA_PATH` with your local storage path.
|
| 9 |
+
- `seq_len`: Window size (samples).
|
| 10 |
+
- `stride`: Step size (samples).
|
| 11 |
+
- Pretraining scripts use **2 kHz** sampling. Downstream scripts use **200 Hz** or **2 kHz**.
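The `seq_len`/`stride` windowing these scripts apply can be sketched as (an illustrative helper, not the scripts' actual implementation):

```python
import numpy as np

def sliding_windows(x: np.ndarray, seq_len: int, stride: int) -> np.ndarray:
    """x: (samples, channels) -> (n_windows, seq_len, channels)."""
    starts = range(0, x.shape[0] - seq_len + 1, stride)
    return np.stack([x[s:s + seq_len] for s in starts])

# 2.5 s of 16-channel data @ 2 kHz, pretraining-style 0.5 s windows, 50% overlap
w = sliding_windows(np.zeros((5000, 16)), seq_len=1000, stride=500)
print(w.shape)  # (9, 1000, 16)
```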
|
| 12 |
|
| 13 |
+
---
|
| 14 |
|
| 15 |
## Pretraining Datasets
|
| 16 |
+
(0.5s windows, 50% overlap @ 2 kHz)
|
| 17 |
|
| 18 |
+
| Dataset | Size (GB) | Seq Len | Stride | Command |
|
| 19 |
+
| :--- | :--- | :--- | :--- | :--- |
|
| 20 |
+
| **EMG2Pose** | 431 | 1000 (0.5s) | 500 | `python scripts/emg2pose.py --data_dir $DATA_PATH/emg2pose_data/ --save_dir $DATA_PATH/emg2pose_data/h5/ --seq_len 1000 --stride 500` |
|
| 21 |
+
| **NinaPro DB6** | ~20 | 1000 (0.5s) | 500 | `python scripts/db6.py --data_dir $DATA_PATH/ninapro/DB6/ --save_dir $DATA_PATH/ninapro/DB6/h5/ --seq_len 1000 --stride 500` |
|
| 22 |
+
| **NinaPro DB7** | ~31 | 1000 (0.5s) | 500 | `python scripts/db7.py --data_dir $DATA_PATH/ninapro/DB7/ --save_dir $DATA_PATH/ninapro/DB7/h5/ --seq_len 1000 --stride 500` |
|
| 23 |
|
| 24 |
---
|
| 25 |
|
| 26 |
## Downstream Datasets
|
| 27 |
|
| 28 |
+
| Dataset | Metric | Seq Len | Stride | Command |
|
| 29 |
+
| :--- | :--- | :--- | :--- | :--- |
|
| 30 |
+
| **NinaPro DB5** | Gesture | 200 (1s) | 50 | `python scripts/db5.py --data_dir $DATA_PATH/ninapro/DB5/ --save_dir $DATA_PATH/ninapro/DB5/h5/ --seq_len 200 --stride 50` |
|
| 31 |
+
| **NinaPro DB5** | Gesture | 1000 (5s) | 250 | `python scripts/db5.py --data_dir $DATA_PATH/ninapro/DB5/ --save_dir $DATA_PATH/ninapro/DB5/h5/ --seq_len 1000 --stride 250` |
|
| 32 |
+
| **EMG-EPN612** | Gesture | 200 (1s) | N/A | `python scripts/epn.py --data_dir $DATA_PATH/EPN612/ --source_training $DATA_PATH/EPN612/trainingJSON/ --source_testing $DATA_PATH/EPN612/testingJSON/ --dest_dir $DATA_PATH/EPN612/h5/ --seq_len 200` |
|
| 33 |
+
| **EMG-EPN612** | Gesture | 1000 (5s) | N/A | `python scripts/epn.py --data_dir $DATA_PATH/EPN612/ --source_training $DATA_PATH/EPN612/trainingJSON/ --source_testing $DATA_PATH/EPN612/testingJSON/ --dest_dir $DATA_PATH/EPN612/h5/ --seq_len 1000` |
|
| 34 |
+
| **UCI EMG** | Gesture | 200 (1s) | 50 | `python scripts/uci.py --data_dir $DATA_PATH/UCI_EMG/EMG_data_for_gestures-master/ --save_dir $DATA_PATH/UCI_EMG/EMG_data_for_gestures-master/h5/ --seq_len 200 --stride 50` |
|
| 35 |
+
| **UCI EMG** | Gesture | 1000 (5s) | 250 | `python scripts/uci.py --data_dir $DATA_PATH/UCI_EMG/EMG_data_for_gestures-master/ --save_dir $DATA_PATH/UCI_EMG/EMG_data_for_gestures-master/h5/ --seq_len 1000 --stride 250` |
|
| 36 |
+
| **NinaPro DB8** | Regression | 200 (0.1s) | 200 | `python scripts/db8.py --data_dir $DATA_PATH/ninapro/DB8/ --save_dir $DATA_PATH/ninapro/DB8/h5/ --seq_len 200 --stride 200` |
|
| 37 |
+
| **NinaPro DB8** | Regression | 1000 (0.5s) | 1000 | `python scripts/db8.py --data_dir $DATA_PATH/ninapro/DB8/ --save_dir $DATA_PATH/ninapro/DB8/h5/ --seq_len 1000 --stride 1000` |
|
| 38 |
|
scripts/db5.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
|
| 4 |
import h5py
|
| 5 |
import numpy as np
|
|
@@ -7,30 +8,77 @@ import scipy.io
|
|
| 7 |
import scipy.signal as signal
|
| 8 |
from scipy.signal import iirnotch
|
| 9 |
|
| 10 |
-
|
| 11 |
|
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
| 15 |
scale = np.random.uniform(*scale_range)
|
| 16 |
return sig * scale
|
| 17 |
|
| 18 |
|
| 19 |
-
def random_time_jitter(sig, jitter_ratio=0.01):
|
| 20 |
T, D = sig.shape
|
| 21 |
std_ch = np.std(sig, axis=0)
|
| 22 |
noise = np.random.randn(T, D) * (jitter_ratio * std_ch)
|
| 23 |
return sig + noise
|
| 24 |
|
| 25 |
|
| 26 |
-
def random_channel_dropout(sig, dropout_prob=0.05):
|
| 27 |
T, D = sig.shape
|
| 28 |
mask = np.random.rand(D) < dropout_prob
|
| 29 |
sig[:, mask] = 0.0
|
| 30 |
return sig
|
| 31 |
|
| 32 |
|
| 33 |
-
def augment_one_sample(seg):
|
| 34 |
out = seg.copy()
|
| 35 |
out = random_amplitude_scale(out, (0.9, 1.1))
|
| 36 |
out = random_time_jitter(out, 0.01)
|
|
@@ -38,7 +86,20 @@ def augment_one_sample(seg):
|
|
| 38 |
return out
|
| 39 |
|
| 40 |
|
| 41 |
-
def augment_train_data(data, labels, factor=3):
|
| 42 |
if factor <= 0 or data.shape[0] == 0:
|
| 43 |
return data, labels
|
| 44 |
aug_segs = [data]
|
|
@@ -55,8 +116,19 @@ def augment_train_data(data, labels, factor=3):
|
|
| 55 |
return new_data, new_labels
|
| 56 |
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
b, a = iirnotch(notch_freq, Q, fs)
|
| 61 |
out = np.zeros_like(data)
|
| 62 |
for ch in range(data.shape[1]):
|
|
@@ -64,7 +136,25 @@ def notch_filter(data, notch_freq=50.0, Q=30.0, fs=200.0):
|
|
| 64 |
return out
|
| 65 |
|
| 66 |
|
| 67 |
-
def bandpass_filter_emg(
|
| 68 |
nyq = 0.5 * fs
|
| 69 |
low = lowcut / nyq
|
| 70 |
high = highcut / nyq
|
|
@@ -75,8 +165,28 @@ def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=200.0, order=4):
|
|
| 75 |
return out
|
| 76 |
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
segs, lbls, reps = [], [], []
|
| 81 |
N = len(label)
|
| 82 |
for start in range(0, N, stride):
|
|
@@ -94,7 +204,6 @@ def process_emg_features(emg, label, rerep, window_size=1024, stride=512):
|
|
| 94 |
return np.array(segs), np.array(lbls), np.array(reps)
|
| 95 |
|
| 96 |
|
| 97 |
-
# ==== Main pipeline ====
|
| 98 |
def main():
|
| 99 |
import argparse
|
| 100 |
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
+
from typing import Tuple, List, Optional, Union, Dict, Any, Callable
|
| 4 |
|
| 5 |
import h5py
|
| 6 |
import numpy as np
|
|
| 8 |
import scipy.signal as signal
|
| 9 |
from scipy.signal import iirnotch
|
| 10 |
|
| 11 |
+
def sequence_to_seconds(seq_len: int, fs: float) -> float:
|
| 12 |
+
"""Converts a sequence length in samples to time in seconds.
|
| 13 |
|
| 14 |
+
Args:
|
| 15 |
+
seq_len (int): The number of samples in the sequence.
|
| 16 |
+
fs (float): The sampling frequency in Hz.
|
| 17 |
|
| 18 |
+
Returns:
|
| 19 |
+
float: The duration of the sequence in seconds.
|
| 20 |
+
"""
|
| 21 |
+
return seq_len / fs
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def random_amplitude_scale(sig: np.ndarray, scale_range: Tuple[float, float] = (0.9, 1.1)) -> np.ndarray:
|
| 25 |
+
"""Applies random amplitude scaling to the input signal.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
sig (np.ndarray): The input signal array of shape (T, D).
|
| 29 |
+
scale_range (Tuple[float, float], optional): The range [min, max] for the scaling factor.
|
| 30 |
+
Defaults to (0.9, 1.1).
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
np.ndarray: The scaled signal array.
|
| 34 |
+
"""
|
| 35 |
scale = np.random.uniform(*scale_range)
|
| 36 |
return sig * scale
|
| 37 |
|
| 38 |
|
| 39 |
+
def random_time_jitter(sig: np.ndarray, jitter_ratio: float = 0.01) -> np.ndarray:
|
| 40 |
+
"""Adds random Gaussian noise (jitter) to the input signal.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
sig (np.ndarray): The input signal array of shape (T, D).
|
| 44 |
+
jitter_ratio (float, optional): The ratio to scale the noise relative to
|
| 45 |
+
each channel's standard deviation. Defaults to 0.01.
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
np.ndarray: The signal with added jitter.
|
| 49 |
+
"""
|
| 50 |
T, D = sig.shape
|
| 51 |
std_ch = np.std(sig, axis=0)
|
| 52 |
noise = np.random.randn(T, D) * (jitter_ratio * std_ch)
|
| 53 |
return sig + noise
|
| 54 |
|
| 55 |
|
| 56 |
+
def random_channel_dropout(sig: np.ndarray, dropout_prob: float = 0.05) -> np.ndarray:
|
| 57 |
+
"""Randomly zeros out channels in the signal based on a probability.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
sig (np.ndarray): The input signal array of shape (T, D).
|
| 61 |
+
dropout_prob (float, optional): Probability of dropping each channel.
|
| 62 |
+
Defaults to 0.05.
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
np.ndarray: The signal with dropped channels.
|
| 66 |
+
"""
|
| 67 |
T, D = sig.shape
|
| 68 |
mask = np.random.rand(D) < dropout_prob
|
| 69 |
sig[:, mask] = 0.0
|
| 70 |
return sig
|
| 71 |
|
| 72 |
|
| 73 |
+
def augment_one_sample(seg: np.ndarray) -> np.ndarray:
|
| 74 |
+
"""Applies a sequence of random augmentations to a single signal segment.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
seg (np.ndarray): Single signal segment of shape (window_size, n_ch).
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
np.ndarray: The augmented signal segment.
|
| 81 |
+
"""
|
| 82 |
out = seg.copy()
|
| 83 |
out = random_amplitude_scale(out, (0.9, 1.1))
|
| 84 |
out = random_time_jitter(out, 0.01)
|
|
| 86 |
return out
|
| 87 |
|
| 88 |
|
| 89 |
+
def augment_train_data(data: np.ndarray, labels: np.ndarray, factor: int = 3) -> Tuple[np.ndarray, np.ndarray]:
|
| 90 |
+
"""Augments the training dataset by creating multiple versions of each sample.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
data (np.ndarray): The input dataset of shape (N, window_size, n_ch).
|
| 94 |
+
labels (np.ndarray): The corresponding labels of shape (N,).
|
| 95 |
+
factor (int, optional): The number of augmented versions to create for each sample.
|
| 96 |
+
Defaults to 3.
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
Tuple[np.ndarray, np.ndarray]: A tuple containing:
|
| 100 |
+
- The augmented dataset.
|
| 101 |
+
- The augmented labels.
|
| 102 |
+
"""
|
| 103 |
if factor <= 0 or data.shape[0] == 0:
|
| 104 |
return data, labels
|
| 105 |
aug_segs = [data]
|
|
|
|
| 116 |
return new_data, new_labels
|
| 117 |
|
| 118 |
|
| 119 |
+
def notch_filter(data: np.ndarray, notch_freq: float = 50.0, Q: float = 30.0, fs: float = 200.0) -> np.ndarray:
|
| 120 |
+
"""Applies a notch filter to remove power line interference.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
data (np.ndarray): The input signal array of shape (T, D).
|
| 124 |
+
notch_freq (float, optional): The frequency to be removed (e.g., 50Hz or 60Hz).
|
| 125 |
+
Defaults to 50.0.
|
| 126 |
+
Q (float, optional): The quality factor. Defaults to 30.0.
|
| 127 |
+
fs (float, optional): The sampling frequency of the signal. Defaults to 200.0.
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
np.ndarray: The filtered signal array.
|
| 131 |
+
"""
|
| 132 |
b, a = iirnotch(notch_freq, Q, fs)
|
| 133 |
out = np.zeros_like(data)
|
| 134 |
for ch in range(data.shape[1]):
|
|
|
|
| 136 |
return out
|
| 137 |
|
| 138 |
|
| 139 |
+
def bandpass_filter_emg(
|
| 140 |
+
emg: np.ndarray,
|
| 141 |
+
lowcut: float = 20.0,
|
| 142 |
+
highcut: float = 90.0,
|
| 143 |
+
fs: float = 200.0,
|
| 144 |
+
order: int = 4
|
| 145 |
+
) -> np.ndarray:
|
| 146 |
+
"""Applies a Butterworth bandpass filter to the EMG signal.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
emg (np.ndarray): The input signal array of shape (T, D).
|
| 150 |
+
lowcut (float, optional): Lower bound of the passband in Hz. Defaults to 20.0.
|
| 151 |
+
highcut (float, optional): Upper bound of the passband in Hz. Defaults to 90.0.
|
| 152 |
+
fs (float, optional): The sampling frequency of the signal. Defaults to 200.0.
|
| 153 |
+
order (int, optional): The order of the filter. Defaults to 4.
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
np.ndarray: The bandpass filtered signal array.
|
| 157 |
+
"""
|
| 158 |
nyq = 0.5 * fs
|
| 159 |
low = lowcut / nyq
|
| 160 |
high = highcut / nyq
|
|
|
|
| 165 |
return out
|
| 166 |
|
| 167 |
|
| 168 |
+
def process_emg_features(
|
| 169 |
+
emg: np.ndarray,
|
| 170 |
+
label: np.ndarray,
|
| 171 |
+
rerep: np.ndarray,
|
| 172 |
+
window_size: int = 1024,
|
| 173 |
+
stride: int = 512
|
| 174 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 175 |
+
"""Segments raw EMG signals into overlapping windows.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
emg (np.ndarray): Raw EMG data of shape (T, n_ch).
|
| 179 |
+
label (np.ndarray): Gesture labels of shape (T,).
|
| 180 |
+
rerep (np.ndarray): Repetition indices of shape (T,).
|
| 181 |
+
window_size (int, optional): Number of samples per window. Defaults to 1024.
|
| 182 |
+
stride (int, optional): Number of samples to shift between windows. Defaults to 512.
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
|
| 186 |
+
- windowed segments (N, window_size, n_ch).
|
| 187 |
+
- labels for each window (N,).
|
| 188 |
+
- repetition indices for each window (N,).
|
| 189 |
+
"""
|
| 190 |
segs, lbls, reps = [], [], []
|
| 191 |
N = len(label)
|
| 192 |
for start in range(0, N, stride):
|
|
|
|
| 204 |
return np.array(segs), np.array(lbls), np.array(reps)
|
| 205 |
|
| 206 |
|
|
|
|
| 207 |
def main():
|
| 208 |
import argparse
|
| 209 |
|
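The jitter and channel-dropout augmentations above are plain NumPy operations. A minimal self-contained sketch (the dropout here copies its input so the caller's array is not mutated, unlike the in-place version in the script):

```python
import numpy as np

def random_time_jitter(sig, jitter_ratio=0.01):
    # Add Gaussian noise scaled per channel by jitter_ratio * channel std
    T, D = sig.shape
    std_ch = np.std(sig, axis=0)
    return sig + np.random.randn(T, D) * (jitter_ratio * std_ch)

def random_channel_dropout(sig, dropout_prob=0.05):
    # Zero out each channel independently with probability dropout_prob
    out = sig.copy()
    mask = np.random.rand(sig.shape[1]) < dropout_prob
    out[:, mask] = 0.0
    return out

np.random.seed(0)
seg = np.random.randn(1024, 14)   # one (window_size, n_ch) segment
aug = random_channel_dropout(random_time_jitter(seg), dropout_prob=0.5)
print(aug.shape)  # (1024, 14)
```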
scripts/db6.py
CHANGED
`scripts/db6.py` after the refactor: the per-subject loop moves out of `main()` into a `process_subject()` helper, subjects are processed in parallel with joblib, and the output HDF5 is written in fixed-size groups (unchanged lines elided):

import numpy as np
import scipy.io
import scipy.signal as signal
from joblib import Parallel, delayed
from scipy.signal import iirnotch
from tqdm import tqdm

sequence_to_seconds = lambda seq_len, fs: seq_len / fs


def notch_filter(data, notch_freq=50.0, Q=30.0, fs=2000.0):
    """Notch-filter every channel independently."""
    b, a = iirnotch(notch_freq, Q, fs)
    ...  # unchanged body elided in the diff view
    return out


def sliding_window_segment(emg, label, rerepetition, window_size, stride):
    """
    Segment EMG with a sliding window.
    """
    ...  # unchanged body elided in the diff view
    return np.array(segments), np.array(labels), np.array(reps)


def process_subject(
    subj_path,
    window_size,
    stride,
    fs,
    train_reps,
    val_reps,
    test_reps,
):
    subj_seg, subj_lbl, subj_rep = [], [], []

    for mat_file in sorted(os.listdir(subj_path)):
        if not mat_file.endswith(".mat"):
            continue
        mat_path = os.path.join(subj_path, mat_file)
        mat = scipy.io.loadmat(mat_path)

        emg = mat["emg"]  # (N, 16)
        label = mat["restimulus"].ravel()
        rerep = mat["rerepetition"].ravel()

        emg = np.delete(emg, [8, 9], axis=1)  # drop empty channels (0-based) -> (N, 14)
        emg = bandpass_filter_emg(emg, 20, 450, fs=fs)
        emg = notch_filter(emg, 50, 30, fs=fs)

        # z-score per channel
        mu = emg.mean(axis=0)
        sd = emg.std(axis=0, ddof=1)
        sd[sd == 0] = 1.0
        emg = (emg - mu) / sd

        seg, lbl, rep = sliding_window_segment(emg, label, rerep, window_size, stride)
        subj_seg.append(seg)
        subj_lbl.append(lbl)
        subj_rep.append(rep)

    if not subj_seg:
        return {
            "train": (np.empty((0, 14, window_size), dtype=np.float32), np.empty((0,), dtype=np.int64)),
            "val": (np.empty((0, 14, window_size), dtype=np.float32), np.empty((0,), dtype=np.int64)),
            "test": (np.empty((0, 14, window_size), dtype=np.float32), np.empty((0,), dtype=np.int64)),
        }

    seg = np.concatenate(subj_seg, axis=0)
    lbl = np.concatenate(subj_lbl)
    rep = np.concatenate(subj_rep)

    out = {}
    for split_name, mask in (
        ("train", np.isin(rep, train_reps)),
        ("val", np.isin(rep, val_reps)),
        ("test", np.isin(rep, test_reps)),
    ):
        X = seg[mask].transpose(0, 2, 1).astype(np.float32)
        y = lbl[mask].astype(np.int64)
        out[split_name] = (X, y)
    return out


def main():
    import argparse
    ...  # earlier argument definitions elided in the diff view
    args.add_argument(
        "--group_size",
        type=int,
        default=1000,
        help="Number of samples per group in the output HDF5 file.",
    )
    args.add_argument(
        "--n_jobs",
        type=int,
        default=-1,
        help="Number of subjects to process in parallel. -1 means all cores.",
    )
    args = args.parse_args()
    data_dir = args.data_dir  # input folder with .mat files
    save_dir = args.save_dir  # output folder for .h5 files
    ...  # setup elided in the diff view

    subject_paths = [
        os.path.join(data_dir, subj)
        for subj in sorted(os.listdir(data_dir))
        if os.path.isdir(os.path.join(data_dir, subj))
    ]

    subject_results = Parallel(n_jobs=args.n_jobs)(
        delayed(process_subject)(
            subj_path,
            window_size,
            stride,
            fs,
            train_reps,
            val_reps,
            test_reps,
        )
        for subj_path in tqdm(subject_paths, desc="Processing subjects")
    )

    for result in subject_results:
        for split_name in ["train", "val", "test"]:
            X, y = result[split_name]
            if X.shape[0] == 0:
                continue
            splits[split_name]["data"].append(X)
            splits[split_name]["label"].append(y)

    ...  # per-split concatenation elided in the diff view
        out_path = os.path.join(save_dir, f"{split}.h5")
        if os.path.exists(out_path):
            os.remove(out_path)
            print(f"Removed existing file {out_path} to avoid overwrite issues.")

        with h5py.File(out_path, "w") as h5f:
            for group_idx, start in enumerate(range(0, X.shape[0], args.group_size)):
                end = min(start + args.group_size, X.shape[0])
                x_chunk = X[start:end].astype(np.float32)
                y_chunk = y[start:end].astype(np.int64)

                grp = h5f.create_group(f"data_group_{group_idx}")
                grp.create_dataset("X", data=x_chunk)
                grp.create_dataset("y", data=y_chunk)

        uniq, cnt = np.unique(y, return_counts=True)
        print(f"\n{split.upper()} → X={X.shape}, label distribution:")
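The output files store fixed-size groups `data_group_0`, `data_group_1`, …, each holding an `X` and a `y` dataset. A minimal sketch of writing and reassembling that layout, using toy shapes and a temporary file:

```python
import os
import tempfile

import h5py
import numpy as np

path = os.path.join(tempfile.mkdtemp(), "train.h5")
X_all = np.arange(5 * 14 * 8, dtype=np.float32).reshape(5, 14, 8)
y_all = np.arange(5, dtype=np.int64)
group_size = 2

# Write in fixed-size groups, mirroring the script's output layout
with h5py.File(path, "w") as f:
    for gi, start in enumerate(range(0, X_all.shape[0], group_size)):
        end = min(start + group_size, X_all.shape[0])
        grp = f.create_group(f"data_group_{gi}")
        grp.create_dataset("X", data=X_all[start:end])
        grp.create_dataset("y", data=y_all[start:end])

# Read groups back in numeric order and reassemble the full arrays
with h5py.File(path, "r") as f:
    keys = sorted(f.keys(), key=lambda k: int(k.rsplit("_", 1)[-1]))
    X = np.concatenate([f[k]["X"][:] for k in keys], axis=0)
    y = np.concatenate([f[k]["y"][:] for k in keys], axis=0)

print(X.shape, y.tolist())  # (5, 14, 8) [0, 1, 2, 3, 4]
```

Sorting keys numerically (rather than lexicographically) matters once there are more than ten groups.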
scripts/db7.py
CHANGED
`scripts/db7.py` gets the same refactor as db6.py (per-subject processing in a `process_subject()` helper run in parallel, chunked HDF5 output), with 16 channels kept and no channel deletion; the empty-split fallback in `main()` is also fixed to the correct shape (unchanged lines elided):

import numpy as np
import scipy.io
import scipy.signal as signal
from joblib import Parallel, delayed
from scipy.signal import iirnotch
from tqdm import tqdm

sequence_to_seconds = lambda seq_len, fs: seq_len / fs


def notch_filter(data, notch_freq=50.0, Q=30.0, fs=2000.0):
    """Notch-filter every channel independently."""
    b, a = iirnotch(notch_freq, Q, fs)
    ...  # unchanged body elided in the diff view
    return out


def sliding_window_segment(emg, label, rerepetition, window_size, stride):
    """
    Segment EMG with a sliding window.
    """
    ...  # unchanged body elided in the diff view
    return np.array(segments), np.array(labels), np.array(reps)


def process_subject(
    subj_path,
    window_size,
    stride,
    fs,
    train_reps,
    val_reps,
    test_reps,
):
    subj_seg, subj_lbl, subj_rep = [], [], []

    for mat_file in sorted(os.listdir(subj_path)):
        if not mat_file.endswith(".mat"):
            continue
        mat_path = os.path.join(subj_path, mat_file)
        mat = scipy.io.loadmat(mat_path)

        emg = mat["emg"]  # (N, 16)
        label = mat["restimulus"].ravel()
        rerep = mat["rerepetition"].ravel()

        emg = bandpass_filter_emg(emg, 20.0, 450.0, fs=fs)
        emg = notch_filter(emg, 50.0, 30.0, fs=fs)

        # z-score per channel
        mu = emg.mean(axis=0)
        sd = emg.std(axis=0, ddof=1)
        sd[sd == 0] = 1.0
        emg = (emg - mu) / sd

        seg, lbl, rep = sliding_window_segment(emg, label, rerep, window_size, stride)
        subj_seg.append(seg)
        subj_lbl.append(lbl)
        subj_rep.append(rep)

    if not subj_seg:
        return {
            "train": (np.empty((0, 16, window_size), dtype=np.float32), np.empty((0,), dtype=np.int64)),
            "val": (np.empty((0, 16, window_size), dtype=np.float32), np.empty((0,), dtype=np.int64)),
            "test": (np.empty((0, 16, window_size), dtype=np.float32), np.empty((0,), dtype=np.int64)),
        }

    seg = np.concatenate(subj_seg, axis=0)
    lbl = np.concatenate(subj_lbl)
    rep = np.concatenate(subj_rep)

    out = {}
    for split_name, mask in (
        ("train", np.isin(rep, train_reps)),
        ("val", np.isin(rep, val_reps)),
        ("test", np.isin(rep, test_reps)),
    ):
        X = seg[mask].transpose(0, 2, 1).astype(np.float32)
        y = lbl[mask].astype(np.int64)
        out[split_name] = (X, y)
    return out


def main():
    import argparse
    ...  # earlier argument definitions elided in the diff view
    args.add_argument(
        "--group_size",
        type=int,
        default=1000,
        help="Number of samples per group in the output HDF5 file.",
    )
    args.add_argument(
        "--n_jobs",
        type=int,
        default=-1,
        help="Number of subjects to process in parallel. -1 means all cores.",
    )
    args = args.parse_args()
    data_dir = args.data_dir  # input folder with .mat files
    save_dir = args.save_dir  # output folder for .h5 files
    ...  # setup elided in the diff view

    subject_paths = [
        os.path.join(data_dir, subj)
        for subj in sorted(os.listdir(data_dir))
        if os.path.isdir(os.path.join(data_dir, subj))
    ]

    subject_results = Parallel(n_jobs=args.n_jobs)(
        delayed(process_subject)(
            subj_path,
            window_size,
            stride,
            fs,
            train_reps,
            val_reps,
            test_reps,
        )
        for subj_path in tqdm(subject_paths, desc="Processing subjects")
    )

    for result in subject_results:
        for split_name in ["train", "val", "test"]:
            X, y = result[split_name]
            if X.shape[0] == 0:
                continue
            splits[split_name]["data"].append(X)
            splits[split_name]["label"].append(y)

    ...  # elided in the diff view
        X = (
            np.concatenate(splits[split]["data"], axis=0)
            if splits[split]["data"]
            else np.empty((0, 16, window_size))
        )
        y = (
            np.concatenate(splits[split]["label"], axis=0)
            if splits[split]["label"]
            else np.empty((0,), dtype=int)
        )

        out_path = os.path.join(save_dir, f"{split}.h5")
        if os.path.exists(out_path):
            os.remove(out_path)
            print(f"Removed existing file {out_path} to avoid overwrite issues.")

        with h5py.File(out_path, "w") as h5f:
            for group_idx, start in enumerate(range(0, X.shape[0], args.group_size)):
                end = min(start + args.group_size, X.shape[0])
                x_chunk = X[start:end].astype(np.float32)
                y_chunk = y[start:end].astype(np.int64)

                grp = h5f.create_group(f"data_group_{group_idx}")
                grp.create_dataset("X", data=x_chunk)
                grp.create_dataset("y", data=y_chunk)

        uniq, cnt = np.unique(y, return_counts=True)
        print(f"\n{split.upper()} → X={X.shape}, label distribution:")
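Both db6.py and db7.py assign windows to train/val/test by repetition id with `np.isin`, so every split contains complete repetitions rather than random windows. A toy illustration (the repetition-to-split assignment below is hypothetical, not the scripts' actual one):

```python
import numpy as np

rep = np.array([1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6])  # repetition id per window
lbl = np.arange(12)                                    # stand-in window indices
train_reps, val_reps, test_reps = [1, 3, 5], [4], [2, 6]

splits = {
    name: lbl[np.isin(rep, reps)]
    for name, reps in (("train", train_reps), ("val", val_reps), ("test", test_reps))
}
print({k: v.tolist() for k, v in splits.items()})
# {'train': [0, 2, 4, 6, 8, 10], 'val': [3, 9], 'test': [1, 5, 7, 11]}
```

Because the masks partition the repetition ids, the three splits are disjoint and cover every window exactly once.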
scripts/db8.py
CHANGED
`scripts/db8.py` gains typing imports, converts the `sequence_to_seconds` lambda into a documented function, and adds Google-style docstrings and type hints throughout (unchanged lines elided):

import os
import sys
from typing import Tuple, List, Optional, Union, Dict, Any

import h5py
import numpy as np
...  # remaining imports elided in the diff view
from scipy.signal import iirnotch
from tqdm import tqdm


def sequence_to_seconds(seq_len: int, fs: float) -> float:
    """Converts a sequence length in samples to time in seconds.

    Args:
        seq_len (int): The number of samples in the sequence.
        fs (float): The sampling frequency in Hz.

    Returns:
        float: The duration of the sequence in seconds.
    """
    return seq_len / fs


_MATRIX_DOF2DOA_TRANSPOSED = np.array(
    # https://www.frontiersin.org/articles/10.3389/fnins.2019.00891/full
    ...  # matrix values elided in the diff view
)
MATRIX_DOF2DOA = _MATRIX_DOF2DOA_TRANSPOSED.T


def notch_filter(data: np.ndarray, notch_freq: float = 50.0, Q: float = 30.0, fs: float = 1111.0) -> np.ndarray:
    """Applies a notch filter to every channel of the input data independently.

    Args:
        data (np.ndarray): The input signal array of shape (T, D).
        notch_freq (float, optional): The frequency to be removed. Defaults to 50.0.
        Q (float, optional): The quality factor. Defaults to 30.0.
        fs (float, optional): The sampling frequency in Hz. Defaults to 1111.0.

    Returns:
        np.ndarray: The filtered signal array.
    """
    b, a = iirnotch(notch_freq, Q, fs)
    out = np.zeros_like(data)
    for ch in range(data.shape[1]):
        ...  # per-channel filtering elided in the diff view
    return out


def bandpass_filter_emg(
    emg: np.ndarray,
    lowcut: float = 20.0,
    highcut: float = 90.0,
    fs: float = 2000.0,
    order: int = 4,
) -> np.ndarray:
    """Applies a Butterworth bandpass filter to the EMG signal.

    Args:
        emg (np.ndarray): The input signal array of shape (T, D).
        lowcut (float, optional): Lower bound of the passband in Hz. Defaults to 20.0.
        highcut (float, optional): Upper bound of the passband in Hz. Defaults to 90.0.
        fs (float, optional): The sampling frequency in Hz. Defaults to 2000.0.
        order (int, optional): The order of the filter. Defaults to 4.

    Returns:
        np.ndarray: The filtered signal array.
    """
    nyq = 0.5 * fs
    b, a = signal.butter(order, [lowcut / nyq, highcut / nyq], btype="bandpass")
    out = np.zeros_like(emg)
    ...  # per-channel filtering elided in the diff view
    return out


def sliding_window_segment(
    emg: np.ndarray,
    label: np.ndarray,
    window_size: int,
    stride: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Segments EMG and label data using a sliding window.

    Args:
        emg (np.ndarray): The raw EMG data of shape (T, n_ch).
        label (np.ndarray): The corresponding labels/targets.
        window_size (int): Number of samples per window.
        stride (int): Number of samples to shift between windows.

    Returns:
        Tuple[np.ndarray, np.ndarray]: A tuple containing:
            - segmented EMG tokens (N, window_size, n_ch).
            - segmented label tokens (N, window_size, target_dim).
    """
    segments, labels = [], []
    n_samples = len(label)
    ...  # window extraction elided in the diff view
    return np.array(segments), np.array(labels)


def process_mat_file(
    mat_path: str,
    window_size: int,
    stride: int,
    fs: float,
) -> Optional[Tuple[str, np.ndarray, np.ndarray]]:
    """Processes a single NinaPro DB8 .mat file.

    Loads the file, removes NaNs, normalizes EMG (Z-score), maps finger degrees
    of freedom (DoF) to degrees of activation (DoA), and segments the data.

    Args:
        mat_path (str): Absolute path to the .mat file.
        window_size (int): Temporal window size in samples.
        stride (int): Stride between windows in samples.
        fs (float): Sampling frequency in Hz.

    Returns:
        Optional[Tuple[str, np.ndarray, np.ndarray]]: A tuple of (split_name, segments, labels)
            if the file is valid, else None.
    """
    mat = scipy.io.loadmat(mat_path)
    emg = mat["emg"]  # (T, 16)
    label = mat["glove"]  # (T, DoF)

    # Drop timesteps with any NaNs in glove data
    valid = ~np.isnan(label).any(axis=1)
    emg = emg[valid]
    label = label[valid]

    # Z-score per channel
    mu = emg.mean(axis=0)
    sd = emg.std(axis=0, ddof=1)
    sd[sd == 0] = 1.0
    emg = (emg - mu) / sd

    # DoF → DoA
    y_doa = (MATRIX_DOF2DOA @ label.T).T

    # Windowing
    segs, labs = sliding_window_segment(emg, y_doa, window_size, stride)

    # Determine split
    fname = os.path.basename(mat_path)
    if "_A1" in fname:
        split = "train"
    ...  # remaining split logic elided in the diff view
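`process_mat_file` maps glove degrees of freedom to degrees of activation with a single matrix product, `(MATRIX_DOF2DOA @ label.T).T`. A shape-level sketch with a made-up 2×3 matrix (the real DB8 matrix from the linked Frontiers paper is larger; the values here are purely illustrative):

```python
import numpy as np

M_DOF2DOA = np.array([[1.0, 0.5, 0.0],
                      [0.0, 0.5, 1.0]])  # hypothetical (n_doa, n_dof) map
glove = np.random.randn(1000, 3)         # (T, n_dof) glove trajectories

y_doa = (M_DOF2DOA @ glove.T).T          # (T, n_doa)
print(y_doa.shape)  # (1000, 2)
```

The double transpose is equivalent to `glove @ M_DOF2DOA.T`; it simply applies the same linear map to every timestep at once.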
scripts/emg2pose.py
CHANGED

import os
import gc
from pathlib import Path
from typing import Tuple, List, Optional, Union, Dict, Any

import h5py
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from scipy.signal import iirnotch
from tqdm import tqdm


def sequence_to_seconds(seq_len: int, fs: float) -> float:
    """Converts a sequence length in samples to time in seconds.

    Args:
        seq_len (int): The number of samples in the sequence.
        fs (float): The sampling frequency in Hz.

    Returns:
        float: The duration of the sequence in seconds.
    """
    return seq_len / fs


def notch_filter(data: np.ndarray, notch_freq: float = 50.0, Q: float = 30.0, fs: float = 2000.0) -> np.ndarray:
    """Applies a notch filter to every channel of the input data independently.

    Args:
        data (np.ndarray): The input signal array of shape (T, D).
        notch_freq (float, optional): The frequency to be removed in Hz. Defaults to 50.0.
        Q (float, optional): The quality factor. Defaults to 30.0.
        fs (float, optional): The sampling frequency in Hz. Defaults to 2000.0.

    Returns:
        np.ndarray: The filtered signal array.
    """
    b, a = iirnotch(notch_freq, Q, fs)
    out = np.zeros_like(data)
    for ch in range(data.shape[1]):
        ...
    return out


def bandpass_filter_emg(
    emg: np.ndarray,
    lowcut: float = 20.0,
    highcut: float = 90.0,
    fs: float = 2000.0,
    order: int = 4
) -> np.ndarray:
    """Applies a Butterworth bandpass filter to the EMG signal.

    Args:
        emg (np.ndarray): The input signal array of shape (T, D).
        lowcut (float, optional): Lower bound of the passband in Hz. Defaults to 20.0.
        highcut (float, optional): Upper bound of the passband in Hz. Defaults to 90.0.
        fs (float, optional): The sampling frequency in Hz. Defaults to 2000.0.
        order (int, optional): The order of the filter. Defaults to 4.

    Returns:
        np.ndarray: The filtered signal array.
    """
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    ...
    return out


def process_emg_features(emg: np.ndarray, window_size: int = 1000, stride: int = 500) -> np.ndarray:
    """Segments raw EMG signals into overlapping windows.

    Args:
        emg (np.ndarray): Raw EMG data of shape (T, n_ch).
        window_size (int, optional): Number of samples per window. Defaults to 1000.
        stride (int, optional): Number of samples to shift between windows. Defaults to 500.

    Returns:
        np.ndarray: Segmented data of shape (N, window_size, n_ch).
    """
    segs = []
    N = len(emg)
    for start in range(0, N, stride):
        end = start + window_size
        ...
    return np.array(segs)


def process_one_recording(file_path: str, fs: float = 2000.0, window_size: int = 1000, stride: int = 500) -> np.ndarray:
    """Processes a single EMG2Pose recording file.

    Loads HDF5 timeseries, filters EMG, normalizes (Z-score), and segments.

    Args:
        file_path (str): Absolute path to the .h5 recording file.
        fs (float, optional): Sampling frequency in Hz. Defaults to 2000.0.
        window_size (int, optional): Temporal window size in samples. Defaults to 1000.
        stride (int, optional): Stride between windows in samples. Defaults to 500.

    Returns:
        np.ndarray: Array of processed segments (N, window_size, n_ch).
    """
    with h5py.File(file_path, "r") as f:
        grp = f["emg2pose"]
        ...
    return segs


def main():
    import argparse

    ...
    args.add_argument(
        "--n_jobs",
        type=int,
        default=-1,
        help="Number of parallel jobs to run. -1 means using all available cores.",
    )
    args.add_argument(
        "--group_size",
        type=int,
        default=1000,
        help="Number of samples per group in the output HDF5 file.",
    )
    args.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility."
    )
    ...
    print(f"Window size: {window_size} samples ({window_seconds:.2f} seconds)")

    df = pd.read_csv(os.path.join(data_dir, "metadata.csv"))
    if args.subsample < 1.0:
        df = df.groupby("split", group_keys=False).sample(
            frac=args.subsample, random_state=args.seed
        )
        df = df.reset_index(drop=True)

    splits = {}
    for split, df_ in df.groupby("split"):
        sessions = list(df_.filename)
        splits[split] = [
            Path(data_dir).expanduser().joinpath(f"{session}.hdf5")
            for session in sessions
        ]

    for split, files in splits.items():
        out_file = os.path.join(save_dir, f"{split}.h5")

        # Remove existing file if it exists so we don't accidentally append to old runs
        if os.path.exists(out_file):
            os.remove(out_file)

        print(f"Processing {split} split ({len(files)} files)...")

        with h5py.File(out_file, "w") as h5f:
            group_idx = 0
            with Parallel(n_jobs=args.n_jobs) as parallel:
                with tqdm(total=len(files), desc=f"Processing & Saving {split}") as pbar:

                    # Iterate files in batches
                    for i in range(0, len(files), args.group_size):
                        batch_files = files[i : i + args.group_size]

                        # Process current batch
                        results = parallel(
                            delayed(process_one_recording)(file_path, fs, window_size, stride)
                            for file_path in batch_files
                        )

                        if results:
                            X_chunk = np.concatenate(results, axis=0)  # [N, window_size, ch]
                            X_chunk = X_chunk.transpose(0, 2, 1)  # [N, ch, window_size]
                            X_chunk = X_chunk.astype(np.float32)

                            # Write each processed batch as a group compatible with HDF5Loader
                            grp = h5f.create_group(f"data_group_{group_idx}")
                            grp.create_dataset("X", data=X_chunk)
                            group_idx += 1

                        # Explicitly clear memory of large numpy arrays
                        del results
                        if "X_chunk" in locals():
                            del X_chunk
                        gc.collect()

                        pbar.update(len(batch_files))


if __name__ == "__main__":
    main()
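The refactored `main()` in emg2pose.py streams recordings in batches of `--group_size` files, writing each batch as a numbered `data_group_{i}` group instead of one monolithic `data` dataset. The batching arithmetic that drives that loop can be checked in isolation (an illustrative helper, not part of the script):

```python
def batch_indices(n_files, group_size):
    """Slice bounds for iterating files in groups, mirroring main()'s loop."""
    return [(i, min(i + group_size, n_files)) for i in range(0, n_files, group_size)]

# 2500 files with group_size=1000 -> three batches, the last one partial
bounds = batch_indices(2500, 1000)
```

Each `(start, end)` pair corresponds to one `data_group_{idx}` written to the split's HDF5 file, so the number of groups equals `ceil(n_files / group_size)` when every batch yields at least one segment.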
scripts/epn.py
CHANGED

import glob
import json
import os
import sys
from typing import Tuple, List, Optional, Union, Dict, Any

import h5py
import numpy as np
import scipy.signal as signal
from joblib import Parallel, delayed
from scipy.signal import iirnotch
from tqdm.auto import tqdm


def sequence_to_seconds(seq_len: int, fs: float) -> float:
    """Converts a sequence length in samples to time in seconds.

    Args:
        seq_len (int): The number of samples in the sequence.
        fs (float): The sampling frequency in Hz.

    Returns:
        float: The duration of the sequence in seconds.
    """
    return seq_len / fs


# Sampling frequency and EMG channels
tfs, n_ch = 200.0, 8
...
gesture_map = {
    ...
}


def bandpass_filter_emg(
    emg: np.ndarray,
    low: float = 20.0,
    high: float = 90.0,
    fs: float = tfs,
    order: int = 4
) -> np.ndarray:
    """Applies a Butterworth bandpass filter to the EMG signal.

    Args:
        emg (np.ndarray): The input signal array of shape (n_ch, T).
        low (float, optional): Lower bound of the passband in Hz. Defaults to 20.0.
        high (float, optional): Upper bound of the passband in Hz. Defaults to 90.0.
        fs (float, optional): The sampling frequency in Hz. Defaults to 200.0.
        order (int, optional): The order of the filter. Defaults to 4.

    Returns:
        np.ndarray: The filtered signal array.
    """
    nyq = 0.5 * fs
    b, a = signal.butter(order, [low / nyq, high / nyq], btype="bandpass")
    return signal.filtfilt(b, a, emg, axis=1)


def notch_filter_emg(
    emg: np.ndarray,
    notch: float = 50.0,
    Q: float = 30.0,
    fs: float = tfs
) -> np.ndarray:
    """Applies a notch filter to remove power line interference.

    Args:
        emg (np.ndarray): The input signal array of shape (n_ch, T).
        notch (float, optional): The frequency to be removed in Hz. Defaults to 50.0.
        Q (float, optional): The quality factor. Defaults to 30.0.
        fs (float, optional): The sampling frequency in Hz. Defaults to 200.0.

    Returns:
        np.ndarray: The filtered signal array.
    """
    w0 = notch / (0.5 * fs)
    b, a = iirnotch(w0, Q)
    return signal.filtfilt(b, a, emg, axis=1)


def zscore_per_channel(emg: np.ndarray) -> np.ndarray:
    """Normalizes the EMG signal using Z-score (per channel).

    Args:
        emg (np.ndarray): The input EMG signal of shape (n_ch, T).

    Returns:
        np.ndarray: The normalized EMG signal.
    """
    mean = emg.mean(axis=1, keepdims=True)
    std = emg.std(axis=1, ddof=1, keepdims=True)
    std[std == 0] = 1.0
    return (emg - mean) / std


def adjust_length(x: np.ndarray, max_len: int) -> np.ndarray:
    """Standardizes the temporal length of the signal by clipping or zero-padding.

    Args:
        x (np.ndarray): The input signal of shape (n_ch, T).
        max_len (int): The target length in samples.

    Returns:
        np.ndarray: The standardized length signal of shape (n_ch, max_len).
    """
    n_ch, seq_len = x.shape
    if seq_len >= max_len:
        return x[:, :max_len]
    ...
    return np.concatenate([x, pad], axis=1)


def extract_emg_signal(sample: Dict[str, Any], seq_len: int) -> Tuple[np.ndarray, int]:
    """Extracts, filters, and normalizes EMG data from a JSON sample.

    Args:
        sample (Dict[str, Any]): A single sample dictionary from the EPN612 JSON.
        seq_len (int): Target temporal length.

    Returns:
        Tuple[np.ndarray, int]: A tuple containing:
            - The preprocessed EMG signal (n_ch, seq_len).
            - The gesture label ID.
    """
    emg = np.stack([v for v in sample["emg"].values()], dtype=np.float32) / 128.0
    emg = bandpass_filter_emg(emg, 20.0, 90.0)
    emg = notch_filter_emg(emg, 50.0, 30.0)
    ...
    return emg, label


def process_user_training(
    path: str,
    seq_len: int
) -> Tuple[List[np.ndarray], List[int], List[np.ndarray], List[int]]:
    """Processes a user's training JSON file for the training and validation splits.

    Args:
        path (str): Path to the user JSON file.
        seq_len (int): Target temporal length for segmentation.

    Returns:
        Tuple[List[np.ndarray], List[int], List[np.ndarray], List[int]]:
            (train_X, train_y, val_X, val_y) lists.
    """
    train_X, train_y, val_X, val_y = [], [], [], []
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    ...
        train_y.append(lbl)
    for sample in data.get("testingSamples", {}).values():
        emg, lbl = extract_emg_signal(sample, seq_len)
        if lbl != 6:
            val_X.append(emg)
            val_y.append(lbl)
    return train_X, train_y, val_X, val_y


def process_user_testing(
    path: str,
    seq_len: int
) -> Tuple[List[np.ndarray], List[int], List[np.ndarray], List[int]]:
    """Processes a user's testing JSON file for the fine-tuning and test splits.

    Args:
        path (str): Path to the user JSON file.
        seq_len (int): Target temporal length for segmentation.

    Returns:
        Tuple[List[np.ndarray], List[int], List[np.ndarray], List[int]]:
            (tune_X, tune_y, test_X, test_y) lists.
    """
    train_X, train_y, test_X, test_y = [], [], [], []
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    ...
    return train_X, train_y, test_X, test_y


def save_h5(path: str, data: List[np.ndarray], labels: List[int]) -> None:
    """Saves the processed EMG data and labels to an HDF5 file.

    Args:
        path (str): Output file path.
        data (List[np.ndarray]): List of signal segments.
        labels (List[int]): List of categorical labels.
    """
    with h5py.File(path, "w") as f:
        f.create_dataset("data", data=np.asarray(data, np.float32))
        f.create_dataset("label", data=np.asarray(labels, np.int64))


def main():
    import argparse
    ...
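epn.py's `adjust_length` either clips the time axis or pads it on the right; the padding branch is elided in the diff. A self-contained sketch that mirrors the helper, with the missing branch filled in as a plausible zero-pad (an assumption, the script may pad differently):

```python
import numpy as np

def adjust_length_sketch(x, max_len):
    """Clip to max_len, or zero-pad on the right (mirrors the helper above)."""
    n_ch, seq_len = x.shape
    if seq_len >= max_len:
        return x[:, :max_len]
    pad = np.zeros((n_ch, max_len - seq_len), dtype=x.dtype)  # assumed zero-pad
    return np.concatenate([x, pad], axis=1)

short = np.ones((8, 150), dtype=np.float32)
long_sig = np.ones((8, 450), dtype=np.float32)
a = adjust_length_sketch(short, 300)     # padded up to 300 samples
b = adjust_length_sketch(long_sig, 300)  # clipped down to 300 samples
```

Either way the output has a fixed `(n_ch, max_len)` shape, which is what lets `save_h5` stack the segments into a single dense dataset.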
scripts/uci.py
CHANGED

import os
import sys
from pathlib import Path
from typing import Tuple, List, Optional, Union, Dict, Any

import h5py
import numpy as np
import scipy.signal as signal
from scipy.signal import iirnotch


def sequence_to_seconds(seq_len: int, fs: float) -> float:
    """Converts a sequence length in samples to time in seconds.

    Args:
        seq_len (int): The number of samples in the sequence.
        fs (float): The sampling frequency in Hz.

    Returns:
        float: The duration of the sequence in seconds.
    """
    return seq_len / fs


def bandpass_filter_emg(
    emg: np.ndarray,
    lowcut: float = 20.0,
    highcut: float = 90.0,
    fs: float = 200.0,
    order: int = 4
) -> np.ndarray:
    """Applies a Butterworth bandpass filter to the EMG signal.

    Args:
        emg (np.ndarray): The input signal array of shape (T, D).
        lowcut (float, optional): Lower bound of the passband in Hz. Defaults to 20.0.
        highcut (float, optional): Upper bound of the passband in Hz. Defaults to 90.0.
        fs (float, optional): The sampling frequency in Hz. Defaults to 200.0.
        order (int, optional): The order of the filter. Defaults to 4.

    Returns:
        np.ndarray: The filtered signal array.
    """
    nyq = 0.5 * fs
    b, a = signal.butter(order, [lowcut / nyq, highcut / nyq], btype="bandpass")
    return signal.filtfilt(b, a, emg, axis=0)


def notch_filter_emg(
    emg: np.ndarray,
    notch_freq: float = 50.0,
    Q: float = 30.0,
    fs: float = 200.0
) -> np.ndarray:
    """Applies a notch filter to remove power line interference.

    Args:
        emg (np.ndarray): The input signal array of shape (T, D).
        notch_freq (float, optional): The frequency to be removed in Hz. Defaults to 50.0.
        Q (float, optional): The quality factor. Defaults to 30.0.
        fs (float, optional): The sampling frequency in Hz. Defaults to 200.0.

    Returns:
        np.ndarray: The filtered signal array.
    """
    b, a = iirnotch(notch_freq / (0.5 * fs), Q)
    return signal.filtfilt(b, a, emg, axis=0)


def read_emg_txt(txt_path: str) -> np.ndarray:
    """Reads a UCI EMG text file into a numpy array.

    The file is expected to have columns: [time, ch1, ..., ch8, class].

    Args:
        txt_path (str): Path to the .txt file.

    Returns:
        np.ndarray: A float32 array of shape (N, 10).
    """
    data = []
    with open(txt_path, "r") as f:
        ...
    return np.asarray(data, dtype=np.float32)


def preprocess_emg(arr: np.ndarray, fs: float = 200.0, remove_class0: bool = True) -> np.ndarray:
    """Applies a standard preprocessing pipeline to the EMG data.

    Pipeline includes:
        1. Optional removal of rest (class 0).
        2. Bandpass filtering (20-90 Hz).
        3. Notch filtering (50 Hz).
        4. Z-score normalization per channel.

    Args:
        arr (np.ndarray): Raw data array of shape (N, 10).
        fs (float, optional): Sampling frequency in Hz. Defaults to 200.0.
        remove_class0 (bool, optional): Whether to remove the "rest" class. Defaults to True.

    Returns:
        np.ndarray: The preprocessed data array.
    """
    if remove_class0:
        arr = arr[arr[:, -1] >= 1]
    ...
    return arr


def find_label_runs(arr: np.ndarray) -> List[Tuple[int, np.ndarray]]:
    """Groups consecutive rows with identical class labels.

    Args:
        arr (np.ndarray): Data array where the last column is the class label.

    Returns:
        List[Tuple[int, np.ndarray]]: A list of tuples (label, sub-array).
    """
    runs = []
    if arr.size == 0:
        return runs
    ...
    return runs


def sliding_window_majority(
    seg_arr: np.ndarray,
    window_size: int = 1000,
    stride: int = 500
) -> Tuple[np.ndarray, np.ndarray]:
    """Segments a label-consistent array using a sliding window and majority voting.

    Args:
        seg_arr (np.ndarray): Data array of shape (T, 10).
        window_size (int, optional): Number of samples per window. Defaults to 1000.
        stride (int, optional): Number of samples to shift between windows. Defaults to 500.

    Returns:
        Tuple[np.ndarray, np.ndarray]: A tuple containing:
            - Windowed EMG segments (N, window_size, 8).
            - Majority vote labels (N,).
    """
    segs, labs = [], []
    for start in range(0, len(seg_arr) - window_size + 1, stride):
        win = seg_arr[start : start + window_size]
        ...


def users_with_gesture(
    data_root: str,
    gesture_id: int,
    subj_range: range = range(1, 37),
    return_counts: bool = False
) -> Union[List[int], Dict[int, int]]:
    """Identifies which subjects performed a specific gesture.

    Args:
        data_root (str): Root directory of the dataset.
        gesture_id (int): The ID of the gesture to search for.
        subj_range (range, optional): Range of subject IDs to check. Defaults to range(1, 37).
        return_counts (bool, optional): If True, returns a dictionary with sample counts.
            Defaults to False.

    Returns:
        Union[List[int], Dict[int, int]]: Either a list of subject IDs or a dictionary
            mapping subject ID to occurrence count.
    """
    found = {}
    for subj in subj_range:
        subj_dir = os.path.join(data_root, f"{subj:02d}")
    ...
    else:
        return sorted(found.keys())


def concat_data(lst: List[np.ndarray]) -> np.ndarray:
    """Concatenates a list of data arrays.

    Args:
        lst (List[np.ndarray]): List of arrays to concatenate.

    Returns:
        np.ndarray: Concatenated array or empty array if list is empty.
    """
    return np.concatenate(lst, axis=0) if lst else np.empty((0, 1000, 8), np.float32)


def concat_label(lst: List[np.ndarray]) -> np.ndarray:
    """Concatenates a list of label arrays.

    Args:
        lst (List[np.ndarray]): List of label arrays.

    Returns:
        np.ndarray: Concatenated array or empty array if list is empty.
    """
    return np.concatenate(lst, axis=0) if lst else np.empty((0,), np.int32)


if __name__ == "__main__":
    import argparse
    ...
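uci.py first splits each recording into label-consistent runs with `find_label_runs`, then windows each run with `sliding_window_majority`; since every window comes from a single run, the majority vote is unanimous by construction. The run-finding step, whose body is elided in the diff, can be sketched with a vectorized numpy re-implementation (illustrative, not the module's actual code):

```python
import numpy as np

def find_label_runs_sketch(arr):
    """Splits rows into maximal runs of identical last-column labels (illustrative)."""
    runs = []
    if arr.size == 0:
        return runs
    labels = arr[:, -1]
    # Indices where the label changes mark run boundaries
    change = np.flatnonzero(np.diff(labels)) + 1
    for chunk in np.split(arr, change):
        runs.append((int(chunk[0, -1]), chunk))
    return runs

# Three runs: label 1 (4 rows), label 2 (3 rows), label 1 again (2 rows)
arr = np.zeros((9, 10), dtype=np.float32)
arr[:, -1] = [1, 1, 1, 1, 2, 2, 2, 1, 1]
runs = find_label_runs_sketch(arr)
run_labels = [lbl for lbl, _ in runs]
run_sizes = [len(chunk) for _, chunk in runs]
```

Note that a label that reappears later starts a new run rather than merging with the earlier one, which is exactly what windowing per run requires.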