# SleepStageNet - Sleep Staging Model Based on Heart Rate / HRV / Respiration / Body Movement
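## 🎯 Overview
SleepStageNet is a deep-learning sleep staging model that performs automatic sleep staging from **4 non-EEG features**:

| Feature | Description | Physiological meaning |
|---------|-------------|-----------------------|
| **HRV (RMSSD)** | Heart rate variability | Reflects autonomic nervous state (parasympathetic activity) |
| **Heart rate (HR)** | Heartbeats per minute | Reflects overall cardiovascular level |
| **Respiratory rate (RR)** | Breaths per minute | Reflects respiratory regulation state |
| **Body movement (Movement)** | Amount of body activity | Reflects limb movement / arousal |

**Output**: a sleep stage for every 30-second epoch → Wake / N1 / N2 / N3 / REM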
## Architecture
The design combines the strongest elements of recent state-of-the-art papers (see References):
```
┌──────────────────────────────────────────────────────────────┐
│                        SleepStageNet                         │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  Input: (batch, T, 4) ← T 30-second epochs, 4 features each  │
│                                                              │
│  ┌──────────────────────────────────────────────────────┐    │
│  │ 1. Feature Projection (after SleepPPG-Net)           │    │
│  │    MLP: 4 → d_model*2 → d_model                      │    │
│  └────────────────────┬─────────────────────────────────┘    │
│                       │                                      │
│  ┌────────────────────┴─────────────────────────────────┐    │
│  │ 2. Cross-Feature Attention (after wav2sleep)         │    │
│  │    Per-feature projection + CLS token + Transformer  │    │
│  │    Learns HRV↔HR↔RR↔Movement interactions            │    │
│  └────────────────────┬─────────────────────────────────┘    │
│                       │ (gated fusion)                       │
│  ┌────────────────────┴─────────────────────────────────┐    │
│  │ 3. Positional Encoding                               │    │
│  │    Sinusoidal positional encoding                    │    │
│  │    (temporal position matters for sleep structure)   │    │
│  └────────────────────┬─────────────────────────────────┘    │
│                       │                                      │
│  ┌────────────────────┴─────────────────────────────────┐    │
│  │ 4. Dilated Temporal CNN (after wav2sleep)            │    │
│  │    2 blocks × [d=1,2,4,8,16,32], k=7                 │    │
│  │    Receptive field ≈ 6 h → spans full sleep cycles   │    │
│  └────────────────────┬─────────────────────────────────┘    │
│                       │                                      │
│  ┌────────────────────┴─────────────────────────────────┐    │
│  │ 5. Classification Head                               │    │
│  │    Linear(d_model → d_model/2 → n_classes)           │    │
│  └──────────────────────────────────────────────────────┘    │
│                                                              │
│  Output: (batch, T, n_classes) ← per-epoch class logits      │
└──────────────────────────────────────────────────────────────┘
```
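
The "≈ 6 hours" receptive field can be checked with a little arithmetic. The sketch below assumes one stride-1 convolution per dilation value, which the diagram does not state; `receptive_field` is a hypothetical helper, not part of the repository.

```python
def receptive_field(kernel_size, dilations, n_blocks, convs_per_dilation=1):
    """Receptive field, in time steps, of stacked stride-1 dilated convolutions.

    Each convolution with kernel size k and dilation d widens the field
    by (k - 1) * d steps.
    """
    rf = 1
    for _ in range(n_blocks):
        for d in dilations:
            rf += convs_per_dilation * (kernel_size - 1) * d
    return rf


EPOCH_SECONDS = 30  # one time step is one 30-second epoch

# "base" configuration from the diagram: 2 blocks, d=[1,2,4,8,16,32], k=7
rf_epochs = receptive_field(kernel_size=7, dilations=[1, 2, 4, 8, 16, 32], n_blocks=2)
rf_hours = rf_epochs * EPOCH_SECONDS / 3600
print(f"{rf_epochs} epochs = {rf_hours:.1f} h")  # 757 epochs = 6.3 h
```

Under that assumption the field is 1 + 2 · 6 · 63 = 757 epochs, or about 6.3 hours, which agrees with the figure in the diagram.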
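### Key innovations
1. **Dual-path gated fusion**: the simple MLP projection and the cross-feature Transformer attention path are merged through a gate, balancing efficiency and expressive power
2. **Random feature masking** (after wav2sleep): during training, features are randomly masked with 30% probability, which improves robustness when a sensor is missing
3. **Dilated temporal CNN** (after wav2sleep): the receptive field covers ≈ 6 hours, enough to capture complete 90-120 minute sleep cycles
4. **Per-patient z-score normalization** (after SleepPPG-Net): removes individual baseline differences; this is the most critical preprocessing step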
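
Items 1 and 2 can be sketched in a few lines of PyTorch. This is a minimal illustration under stated assumptions, not the repository's implementation: `mask_features` and `GatedFusion` are hypothetical names, the masking is assumed to zero whole feature channels independently per recording, and the gate is assumed to be a sigmoid over the concatenated streams.

```python
import torch
import torch.nn as nn


def mask_features(x: torch.Tensor, p: float = 0.3) -> torch.Tensor:
    """Zero each feature channel of each recording with probability p.

    x has shape (batch, T, n_features). After z-scoring, 0 is the channel
    mean, so a zeroed channel resembles an uninformative sensor.
    Assumption: channels are dropped independently, so all of them can
    occasionally be masked at once. Apply during training only.
    """
    keep = (torch.rand(x.shape[0], 1, x.shape[2], device=x.device) >= p).to(x.dtype)
    return x * keep


class GatedFusion(nn.Module):
    """Blend two (batch, T, d_model) streams with a learned per-channel gate."""

    def __init__(self, d_model: int):
        super().__init__()
        self.gate = nn.Linear(2 * d_model, d_model)

    def forward(self, mlp_out: torch.Tensor, attn_out: torch.Tensor) -> torch.Tensor:
        g = torch.sigmoid(self.gate(torch.cat([mlp_out, attn_out], dim=-1)))
        return g * mlp_out + (1.0 - g) * attn_out


# Usage with random stand-in tensors
x = torch.randn(2, 100, 4)
x_masked = mask_features(x, p=0.3)
fusion = GatedFusion(d_model=128)
fused = fusion(torch.randn(2, 100, 128), torch.randn(2, 100, 128))
print(x_masked.shape, fused.shape)
```

Because the gate outputs values between 0 and 1, the fused representation is a convex combination of the two paths at every position and channel.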
## Model configurations

| Config | d_model | Parameters | Epoch Mixer | Sequence Mixer | Use case |
|--------|---------|------------|-------------|----------------|----------|
| small | 64 | ~195K | 1-layer Transformer | 1 block, d=[1,2,4,8,16] | Quick experiments |
| **base** | 128 | **~2M** | 2-layer Transformer | 2 blocks, d=[1,2,4,8,16,32] | **Recommended** |
| large | 256 | ~13.7M | 3-layer Transformer | 3 blocks, d=[1,2,4,8,16,32,64] | Best performance |
## 🔬 References

| Paper | Contribution | Year |
|-------|--------------|------|
| [wav2sleep](https://arxiv.org/abs/2411.04644) | Epoch Mixer + Sequence Mixer + random modality masking | 2024 |
| [Cross-Modal Transformer](https://arxiv.org/abs/2208.06991) | Cross-modal attention + weighted cross-entropy [1,2,1,2,2] | 2022 |
| [SleepPPG-Net](https://arxiv.org/abs/2202.05735) | Per-patient z-score + BiLSTM FE baseline | 2022 |
| [Mamba-sleep](https://arxiv.org/abs/2412.15947) | Lightweight Mamba sequence modeling + inverse-frequency weighting | 2024 |
## Quick start
### Install dependencies
```bash
pip install torch numpy
```
### Inference example
```python
import numpy as np
import torch
from sleep_staging_model import create_model

# Create the model (n_classes=4: Wake/N1/N2/N3; predicting REM needs n_classes=5)
model = create_model('base', n_features=4, n_classes=4)

# Prepare the input: features for T 30-second epochs.
# hrv_rmssd, heart_rate, respiratory_rate and body_movement are your own
# 1-D arrays of length T (one value per epoch).
T = 1200  # 10 hours = 1200 epochs
features = np.stack([
    hrv_rmssd,         # HRV (RMSSD) series
    heart_rate,        # heart rate series
    respiratory_rate,  # respiratory rate series
    body_movement,     # body movement series
], axis=-1)  # shape: (T, 4)

# Z-score normalization over this recording (critical step!)
features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8)
features = np.clip(features, -5, 5)

# Inference
x = torch.tensor(features, dtype=torch.float32).unsqueeze(0)  # (1, T, 4)
model.eval()
with torch.no_grad():
    logits = model(x)                           # (1, T, n_classes)
    predictions = torch.argmax(logits, dim=-1)  # (1, T)

# Labels: 0=Wake, 1=N1, 2=N2, 3=N3, 4=REM (index 4 appears only with n_classes=5)
stage_names = {0: 'Wake', 1: 'N1', 2: 'N2', 3: 'N3', 4: 'REM'}
```
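
The per-epoch predictions can be turned into a simple night summary using the `stage_names` mapping. The sketch below is illustrative: `stage_minutes` is a hypothetical helper, and the short `example` array stands in for `predictions[0].numpy()`.

```python
import numpy as np

EPOCH_SECONDS = 30
stage_names = {0: 'Wake', 1: 'N1', 2: 'N2', 3: 'N3', 4: 'REM'}


def stage_minutes(pred, names=stage_names, epoch_seconds=EPOCH_SECONDS):
    """Return the minutes spent in each stage for a 1-D array of stage indices."""
    pred = np.asarray(pred, dtype=np.int64).ravel()
    counts = np.bincount(pred, minlength=len(names))
    return {names[i]: float(counts[i]) * epoch_seconds / 60 for i in range(len(names))}


# Stand-in for predictions[0].numpy(): ten epochs, i.e. five minutes
example = np.array([0, 0, 1, 2, 2, 2, 3, 3, 2, 0])
summary = stage_minutes(example)
for stage, minutes in summary.items():
    print(f"{stage}: {minutes:.1f} min")
```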
### Training
```bash
python train_sleep_staging.py --model base --batch_size 16 --lr 1e-3 --max_epochs 100
```
## Training configuration (best practices from the literature)

| Hyperparameter | Value | Source |
|----------------|-------|--------|
| Optimizer | AdamW | wav2sleep |
| Learning Rate | 1e-3 | wav2sleep |
| Weight Decay | 1e-2 | wav2sleep |
| Batch Size | 16 (full nights) | wav2sleep |
| LR Schedule | OneCycleLR (10% warmup + cosine) | Adapted from wav2sleep |
| Loss | Weighted Focal Loss (γ=2) | Cross-Modal Transformer + Focal |
| Class Weights | Inverse-frequency weighting | Mamba-sleep |
| Early Stopping | patience=10 | wav2sleep (patience=5) |
| Gradient Clip | max_norm=1.0 | Standard practice |
| Augmentation | Random flip (p=0.5) + noise (p=0.3) | wav2sleep |
| Feature Mask | p=0.3 | wav2sleep |
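
To make the table concrete, here is a minimal sketch of one training step wiring these settings together in PyTorch. It is not taken from `train_sleep_staging.py`: the helper names, the `nn.Linear` stand-in model, the step counts, and the exact normalization of the class weights are all assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def inverse_frequency_weights(labels: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Class weights proportional to 1 / class frequency (average 1 when balanced)."""
    counts = torch.bincount(labels.flatten(), minlength=n_classes).float().clamp(min=1)
    return counts.sum() / (n_classes * counts)


def weighted_focal_loss(logits, targets, class_weights, gamma=2.0):
    """Focal loss over (batch, T, C) logits and (batch, T) integer targets."""
    logits = logits.reshape(-1, logits.shape[-1])
    targets = targets.reshape(-1)
    log_pt = F.log_softmax(logits, dim=-1).gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    return (-class_weights[targets] * (1.0 - pt) ** gamma * log_pt).mean()


n_classes = 5
model = nn.Linear(4, n_classes)          # stand-in for the real network
steps_per_epoch, max_epochs = 30, 100    # assumed values for the schedule length

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-3, total_steps=steps_per_epoch * max_epochs,
    pct_start=0.1, anneal_strategy="cos",  # 10% warmup, cosine decay
)

# One training step on random stand-in data shaped (batch, T, 4)
x = torch.randn(2, 50, 4)
y = torch.randint(0, n_classes, (2, 50))
weights = inverse_frequency_weights(y, n_classes)
loss = weighted_focal_loss(model(x), y, weights, gamma=2.0)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
scheduler.step()
```

With `gamma=0` and uniform weights the loss reduces to ordinary cross-entropy, which is a convenient sanity check when adapting it.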
## Dataset
Training data: [`abmallick/heart-breath-sleep-stage-dataset`](https://huggingface.co/datasets/abmallick/heart-breath-sleep-stage-dataset)
- 30-second epoch granularity
- Contains heart rate, respiratory rate, and HRV metrics
- 485+ nights of sleep recordings
- 4-class labels: Wake, N1, N2, N3
## Expected performance (based on the literature)

| Metric | 4-class (no REM) | 5-class (with REM) |
|--------|------------------|--------------------|
| Cohen's κ | 0.55-0.65 | 0.50-0.60 |
| Accuracy | 70-80% | 65-75% |
| F1 (macro) | 0.50-0.65 | 0.45-0.60 |

> Note: sleep staging from non-EEG features typically performs below EEG-based methods (κ≈0.75+); this is an inherent limitation of the field.
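
For reference, the metrics in the table can be computed from a confusion matrix as sketched below with NumPy. The function names are illustrative and the tiny label arrays are made-up inputs, not results from this model.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, n_classes):
    """cm[i, j] counts epochs whose true stage is i and predicted stage is j."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def cohen_kappa(cm):
    """Agreement corrected for chance: (p_o - p_e) / (1 - p_e)."""
    n = cm.sum()
    p_observed = np.trace(cm) / n
    p_expected = (cm.sum(axis=0) * cm.sum(axis=1)).sum() / n**2
    return (p_observed - p_expected) / (1.0 - p_expected)


def macro_f1(cm):
    """Unweighted mean of per-class F1 = 2TP / (2TP + FP + FN)."""
    tp = np.diag(cm).astype(float)
    denom = (cm.sum(axis=0) + cm.sum(axis=1)).astype(float)
    f1 = np.divide(2 * tp, denom, out=np.zeros_like(tp), where=denom > 0)
    return f1.mean()


# Made-up two-class example
y_true = [0, 0, 1, 1]
y_pred = [0, 1, 1, 1]
cm = confusion_matrix(y_true, y_pred, n_classes=2)
accuracy = np.trace(cm) / cm.sum()
print(f"kappa={cohen_kappa(cm):.2f} accuracy={accuracy:.2f} macro_f1={macro_f1(cm):.2f}")
```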
## License
MIT