# SleepStageNet - Sleep staging model based on heart rate, HRV, respiration, and body movement
## 🎯 Overview
SleepStageNet is a deep-learning sleep staging model that performs automatic sleep staging from **4 non-EEG features**:
| Feature | Description | Physiological meaning |
|---------|-------------|-----------------------|
| **HRV (RMSSD)** | Heart rate variability | Reflects autonomic state (parasympathetic activity) |
| **Heart rate (HR)** | Beats per minute | Reflects overall cardiovascular activity |
| **Respiratory rate (RR)** | Breaths per minute | Reflects respiratory regulation |
| **Body movement (Movement)** | Amount of body activity | Reflects limb movement / arousal |
**Output**: one sleep stage per 30-second epoch → Wake / N1 / N2 / N3 / REM
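The HRV and heart-rate inputs in the table above are usually derived from inter-beat intervals. The sketch below shows the textbook definitions (RMSSD as the root mean square of successive interval differences, HR as 60,000 ms divided by the mean interval). The function name `epoch_hr_and_rmssd` and the sample intervals are illustrative assumptions, not part of this repository.

```python
import numpy as np

def epoch_hr_and_rmssd(ibi_ms):
    """Return (heart rate in bpm, RMSSD in ms) for one epoch of inter-beat intervals."""
    ibi_ms = np.asarray(ibi_ms, dtype=float)
    hr_bpm = 60_000.0 / ibi_ms.mean()        # 60,000 ms per minute / mean interval
    successive_diff = np.diff(ibi_ms)        # differences between adjacent intervals
    rmssd_ms = np.sqrt(np.mean(successive_diff ** 2))
    return hr_bpm, rmssd_ms

# Hypothetical inter-beat intervals (ms) inside one 30-second epoch.
hr, rmssd = epoch_hr_and_rmssd([1000, 1020, 980, 1000])
print(f"HR = {hr:.1f} bpm, RMSSD = {rmssd:.1f} ms")  # HR = 60.0 bpm, RMSSD = 28.3 ms
```

Computing one value per 30-second epoch yields the `(T,)` series that the model expects for each feature.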
## 📐 Architecture
The architecture combines design choices from recent state-of-the-art papers (listed in the references section):
```
┌──────────────────────────────────────────────────────────────┐
│                        SleepStageNet                         │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  Input: (batch, T, 4)  ← T 30-second epochs, 4 features each │
│                                                              │
│  ┌──────────────────────────────────────────────────────┐    │
│  │ 1. Feature Projection (after SleepPPG-Net)           │    │
│  │    MLP: 4 → d_model*2 → d_model                      │    │
│  └───────────────────────────┬──────────────────────────┘    │
│                              │                               │
│  ┌───────────────────────────┴──────────────────────────┐    │
│  │ 2. Cross-Feature Attention (after wav2sleep)         │    │
│  │    Per-feature projection + CLS token + Transformer  │    │
│  │    Learns HRV↔HR↔RR↔Movement interactions            │    │
│  └───────────────────────────┬──────────────────────────┘    │
│                              │ (gated fusion)                │
│  ┌───────────────────────────┴──────────────────────────┐    │
│  │ 3. Positional Encoding                               │    │
│  │    Sinusoidal (position in the night matters)        │    │
│  └───────────────────────────┬──────────────────────────┘    │
│                              │                               │
│  ┌───────────────────────────┴──────────────────────────┐    │
│  │ 4. Dilated Temporal CNN (after wav2sleep)            │    │
│  │    2 blocks × [d=1,2,4,8,16,32], k=7                 │    │
│  │    Receptive field ≈ 6 h → spans full sleep cycles   │    │
│  └───────────────────────────┬──────────────────────────┘    │
│                              │                               │
│  ┌───────────────────────────┴──────────────────────────┐    │
│  │ 5. Classification Head                               │    │
│  │    Linear(d_model → d_model/2 → n_classes)           │    │
│  └──────────────────────────────────────────────────────┘    │
│                                                              │
│  Output: (batch, T, n_classes)  ← per-epoch class logits     │
└──────────────────────────────────────────────────────────────┘
```
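Step 3 of the diagram names a sinusoidal positional encoding but does not spell out the formula. A minimal NumPy sketch, assuming the original Transformer formulation (sine on even channels, cosine on odd channels, base 10000) and an even `d_model`; the actual implementation in this repository may differ:

```python
import numpy as np

def sinusoidal_positional_encoding(n_positions, d_model):
    """Standard sinusoidal encoding; returns an array of shape (n_positions, d_model)."""
    positions = np.arange(n_positions)[:, None]        # (T, 1) epoch indices
    even_dims = np.arange(0, d_model, 2)[None, :]      # (1, d_model / 2)
    angles = positions / np.power(10000.0, even_dims / d_model)
    pe = np.zeros((n_positions, d_model))
    pe[:, 0::2] = np.sin(angles)   # even channels
    pe[:, 1::2] = np.cos(angles)   # odd channels
    return pe

# One encoding vector per 30-second epoch of a 10-hour night, base configuration.
pe = sinusoidal_positional_encoding(n_positions=1200, d_model=128)
print(pe.shape)  # (1200, 128)
```

The encoding is added to the per-epoch embeddings so that later layers can tell early-night epochs from late-night ones.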
### Key innovations
1. **Dual-path gated fusion**: a simple MLP projection and a cross-feature Transformer attention path are merged through a gate, balancing efficiency and expressiveness.
2. **Random feature masking** (after wav2sleep): during training, features are randomly masked with probability 0.3, which improves robustness when a sensor is missing.
3. **Dilated temporal CNN** (after wav2sleep): the receptive field covers ≈6 hours, enough to capture complete 90-120 minute sleep cycles.
4. **Per-patient z-score normalization** (after SleepPPG-Net): removes individual baseline differences; this is the most critical preprocessing step.
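Random feature masking (item 2 above) can be sketched as follows. This is an illustration under stated assumptions, not the repository's code: it assumes whole feature channels are masked per recording, that a masked channel is replaced by zeros (the mean after z-scoring), and that at least one channel is always kept. The name `mask_features` is hypothetical.

```python
import numpy as np

def mask_features(x, p=0.3, rng=None):
    """Zero whole feature channels of x (shape (T, n_features)) with probability p each."""
    rng = np.random.default_rng() if rng is None else rng
    keep = rng.random(x.shape[-1]) >= p       # True = channel stays
    if not keep.any():                        # never drop every channel
        keep[rng.integers(x.shape[-1])] = True
    return x * keep, keep

rng = np.random.default_rng(0)
x = rng.standard_normal((1200, 4))            # hypothetical z-scored night
x_masked, keep = mask_features(x, p=0.3, rng=rng)
print("channels kept:", keep)
```

Masking is applied only during training; at inference time all available features are passed through unchanged.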
## 📊 Model configurations
| Config | d_model | Parameters | Epoch Mixer | Sequence Mixer | Use case |
|--------|---------|------------|-------------|----------------|----------|
| small | 64 | ~195K | 1-layer Transformer | 1 block, d=[1,2,4,8,16] | Quick experiments |
| **base** | 128 | **~2M** | 2-layer Transformer | 2 blocks, d=[1,2,4,8,16,32] | **Recommended** |
| large | 256 | ~13.7M | 3-layer Transformer | 3 blocks, d=[1,2,4,8,16,32,64] | Best performance |
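The ≈6 hour receptive field quoted for the base configuration can be checked with a short calculation. The sketch assumes one stride-1 convolution per dilation value and kernel size k=7 in every configuration (the README states k=7 only in the architecture diagram); under those assumptions the receptive field is `1 + n_blocks * (k - 1) * sum(dilations)` epochs.

```python
def receptive_field_epochs(n_blocks, dilations, kernel_size=7):
    """Receptive field (in epochs) of stacked stride-1 dilated convolutions."""
    return 1 + n_blocks * (kernel_size - 1) * sum(dilations)

# (number of blocks, dilation schedule) taken from the configuration table.
configs = {
    "small": (1, [1, 2, 4, 8, 16]),
    "base": (2, [1, 2, 4, 8, 16, 32]),
    "large": (3, [1, 2, 4, 8, 16, 32, 64]),
}
for name, (n_blocks, dilations) in configs.items():
    epochs = receptive_field_epochs(n_blocks, dilations)
    print(f"{name}: {epochs} epochs = {epochs * 30 / 3600:.1f} h")
# small: 187 epochs = 1.6 h
# base: 757 epochs = 6.3 h
# large: 2287 epochs = 19.1 h
```

Under these assumptions only the base and large configurations span several 90-120 minute sleep cycles; the small one covers roughly a single cycle.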
## 🔬 References
| Paper | Contribution | Year |
|-------|--------------|------|
| [wav2sleep](https://arxiv.org/abs/2411.04644) | Epoch Mixer + Sequence Mixer + random modality masking | 2024 |
| [Cross-Modal Transformer](https://arxiv.org/abs/2208.06991) | Cross-modal attention + weighted cross-entropy [1,2,1,2,2] | 2022 |
| [SleepPPG-Net](https://arxiv.org/abs/2202.05735) | Per-patient z-score + BiLSTM FE baseline | 2022 |
| [Mamba-sleep](https://arxiv.org/abs/2412.15947) | Lightweight Mamba sequence modeling + inverse-frequency weighting | 2024 |
## 🚀 Quick start
### Install dependencies
```bash
pip install torch numpy
```
### Inference example
```python
import numpy as np
import torch
from sleep_staging_model import create_model

# Build the model. n_classes=4 matches the 4-class dataset labels (Wake, N1, N2, N3).
model = create_model('base', n_features=4, n_classes=4)

# Prepare the input: one feature vector per 30-second epoch.
# hrv_rmssd, heart_rate, respiratory_rate and body_movement are 1-D arrays
# of length T that you supply from your own recording.
T = 1200  # 10 hours = 1200 epochs
features = np.stack([
    hrv_rmssd,         # HRV (RMSSD) series
    heart_rate,        # heart-rate series
    respiratory_rate,  # respiratory-rate series
    body_movement,     # body-movement series
], axis=-1)  # shape: (T, 4)

# Per-recording z-score normalization (critical step!)
features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8)
features = np.clip(features, -5, 5)

# Inference
x = torch.tensor(features, dtype=torch.float32).unsqueeze(0)  # (1, T, 4)
model.eval()
with torch.no_grad():
    logits = model(x)  # (1, T, n_classes)
    predictions = torch.argmax(logits, dim=-1)  # (1, T)

# Labels: 0=Wake, 1=N1, 2=N2, 3=N3; index 4 (REM) is only produced by a 5-class model
stage_names = {0: 'Wake', 1: 'N1', 2: 'N2', 3: 'N3', 4: 'REM'}
```
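Once per-epoch predictions are available, they can be summarized as time spent in each stage. A minimal sketch in plain Python; `minutes_per_stage` is a hypothetical helper, and the example list stands in for something like `predictions[0].tolist()` from the inference example:

```python
from collections import Counter

STAGE_NAMES = {0: "Wake", 1: "N1", 2: "N2", 3: "N3", 4: "REM"}
EPOCH_SECONDS = 30

def minutes_per_stage(predictions):
    """Map per-epoch class indices to total minutes spent in each stage."""
    counts = Counter(STAGE_NAMES[int(p)] for p in predictions)
    return {stage: n * EPOCH_SECONDS / 60 for stage, n in counts.items()}

# Hypothetical per-epoch predictions for ten epochs (five minutes of recording).
example = [0, 0, 1, 2, 2, 2, 3, 3, 2, 0]
print(minutes_per_stage(example))
# {'Wake': 1.5, 'N1': 0.5, 'N2': 2.0, 'N3': 1.0}
```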
### Training
```bash
python train_sleep_staging.py --model base --batch_size 16 --lr 1e-3 --max_epochs 100
```
## 📋 Training configuration (best practices from the literature)
| Hyperparameter | Value | Source |
|----------------|-------|--------|
| Optimizer | AdamW | wav2sleep |
| Learning Rate | 1e-3 | wav2sleep |
| Weight Decay | 1e-2 | wav2sleep |
| Batch Size | 16 (whole nights) | wav2sleep |
| LR Schedule | OneCycleLR (10% warmup + cosine) | Adapted from wav2sleep |
| Loss | Weighted Focal Loss (γ=2) | Cross-Modal Transformer + Focal |
| Class Weights | Inverse-frequency weighting | Mamba-sleep |
| Early Stopping | patience=10 | wav2sleep (uses 5) |
| Gradient Clip | max_norm=1.0 | Standard practice |
| Augmentation | Random flip (p=0.5) + noise (p=0.3) | wav2sleep |
| Feature Mask | p=0.3 | wav2sleep |
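The loss and class-weight rows can be illustrated with a small NumPy sketch. It assumes the common focal-loss form `-w_y * (1 - p_y)^γ * log(p_y)` and inverse-frequency weights normalised to a mean of 1; the function names are hypothetical, and the training script may normalise differently or use PyTorch tensors instead.

```python
import numpy as np

def inverse_frequency_weights(labels, n_classes):
    """Class weights proportional to 1 / class frequency, normalised to mean 1."""
    counts = np.bincount(labels, minlength=n_classes).astype(float)
    weights = 1.0 / np.maximum(counts, 1.0)     # guard against empty classes
    return weights * n_classes / weights.sum()

def weighted_focal_loss(logits, labels, class_weights, gamma=2.0):
    """Mean focal loss: -w_y * (1 - p_y)^gamma * log(p_y) over all epochs."""
    shifted = logits - logits.max(axis=-1, keepdims=True)   # numerically stable softmax
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    log_p_true = log_probs[np.arange(len(labels)), labels]
    p_true = np.exp(log_p_true)
    return float(np.mean(-class_weights[labels] * (1.0 - p_true) ** gamma * log_p_true))

labels = np.array([2, 2, 2, 0, 1, 3, 2, 2])     # hypothetical epoch labels, N2-heavy
logits = np.zeros((len(labels), 4))             # uniform predictions, p = 0.25
w = inverse_frequency_weights(labels, n_classes=4)
print(w, weighted_focal_loss(logits, labels, w))
```

With `gamma=0` and unit weights the expression reduces to ordinary cross-entropy; raising `gamma` down-weights epochs the model already classifies confidently, and the class weights counter the dominance of N2.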
## 📊 Dataset
Training data: [`abmallick/heart-breath-sleep-stage-dataset`](https://huggingface.co/datasets/abmallick/heart-breath-sleep-stage-dataset)
- 30-second epoch granularity
- Includes heart rate, respiratory rate, and HRV metrics
- 485+ nights of sleep recordings
- 4-class labels: Wake, N1, N2, N3
## Expected performance (based on the literature)
| Metric | 4-class (no REM) | 5-class (with REM) |
|--------|------------------|--------------------|
| Cohen's κ | 0.55-0.65 | 0.50-0.60 |
| Accuracy | 70-80% | 65-75% |
| F1 (macro) | 0.50-0.65 | 0.45-0.60 |
> Note: sleep staging from non-EEG features typically scores lower than EEG-based methods (κ≈0.75+); this is an inherent limitation of the field.
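Cohen's κ, the headline metric in the table, corrects raw accuracy for the agreement expected by chance: κ = (p_o - p_e) / (1 - p_e). A minimal sketch in plain Python with made-up labels, for checking your own evaluation numbers:

```python
from collections import Counter

def cohens_kappa(y_true, y_pred):
    """Cohen's kappa: (p_o - p_e) / (1 - p_e) for two label sequences."""
    n = len(y_true)
    p_observed = sum(t == p for t, p in zip(y_true, y_pred)) / n
    true_counts, pred_counts = Counter(y_true), Counter(y_pred)
    p_expected = sum(true_counts[c] * pred_counts[c] for c in true_counts) / n**2
    return (p_observed - p_expected) / (1 - p_expected)

# Hypothetical labels for ten epochs (0=Wake, 1=N1, 2=N2, 3=N3).
y_true = [0, 0, 1, 2, 2, 2, 3, 3, 2, 0]
y_pred = [0, 1, 1, 2, 2, 2, 3, 2, 2, 0]
print(f"kappa = {cohens_kappa(y_true, y_pred):.3f}")  # kappa = 0.714
```

Here accuracy is 80% but κ is only 0.714, because 30% agreement would be expected by chance given the class distributions.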
## License
MIT