Add comprehensive README with architecture docs
Browse files
README.md
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SleepStageNet - ๅบไบๅฟ็/HRV/ๅผๅธ/ไฝๅจ็็ก็ ๅๆๆจกๅ
|
| 2 |
+
|
| 3 |
+
## ๐ฏ ๆฆ่ฟฐ
|
| 4 |
+
|
| 5 |
+
SleepStageNet ๆฏไธไธชๅบไบๆทฑๅบฆๅญฆไน ็็ก็ ๅๆๆจกๅ๏ผไฝฟ็จ **4ไธช้EEG็นๅพ** ่ฟ่ก่ชๅจ็ก็ ๅๆ๏ผ
|
| 6 |
+
|
| 7 |
+
| ็นๅพ | ่ฏดๆ | ็็ๆไน |
|
| 8 |
+
|------|------|---------|
|
| 9 |
+
| **HRV (RMSSD)** | ๅฟ็ๅๅผๆง | ๅๆ ่ชไธป็ฅ็ป็ถๆ๏ผๅฏไบคๆๆดปๆง๏ผ |
|
| 10 |
+
| **ๅฟ็ (HR)** | ๆฏๅ้ๅฟ่ทณๆฌกๆฐ | ๅๆ ๆดไฝๅฟ่ก็ฎกๆฐดๅนณ |
|
| 11 |
+
| **ๅผๅธ้ข็ (RR)** | ๆฏๅ้ๅผๅธๆฌกๆฐ | ๅๆ ๅผๅธ่ฐ่็ถๆ |
|
| 12 |
+
| **ไฝๅจ (Movement)** | ่บซไฝๆดปๅจ้ | ๅๆ ่ขไฝ่ฟๅจ/่ง้ |
|
| 13 |
+
|
| 14 |
+
**่พๅบ**๏ผๆฏ30็งepoch็็ก็ ๅๆ โ Wake / N1 / N2 / N3 / REM
|
| 15 |
+
|
| 16 |
+
## ๐ ๆถๆ่ฎพ่ฎก
|
| 17 |
+
|
| 18 |
+
็ปผๅไบไปฅไธSOTA่ฎบๆ็ๆไฝณ่ฎพ่ฎก๏ผ
|
| 19 |
+
|
| 20 |
+
```
|
| 21 |
+
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 22 |
+
โ SleepStageNet โ
|
| 23 |
+
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค
|
| 24 |
+
โ โ
|
| 25 |
+
โ ่พๅ
ฅ: (batch, T, 4) โ Tไธช30็งepoch, ๆฏไธชๆ4ไธช็นๅพ โ
|
| 26 |
+
โ โ
|
| 27 |
+
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
|
| 28 |
+
โ โ 1. Feature Projection (ๅ่SleepPPG-Net) โ โ
|
| 29 |
+
โ โ MLP: 4 โ d_model*2 โ d_model โ โ
|
| 30 |
+
โ โโโโโโโโโโโโโโโโโโโโฌโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
|
| 31 |
+
โ โ โ
|
| 32 |
+
โ โโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
|
| 33 |
+
โ โ 2. Cross-Feature Attention (ๅ่wav2sleep) โ โ
|
| 34 |
+
โ โ ๆฏไธช็นๅพ็ฌ็ซๆๅฝฑ + CLS Token + Transformer โ โ
|
| 35 |
+
โ โ ๅญฆไน HRVโHRโRRโMovement ็ไบคไบๅ
ณ็ณป โ โ
|
| 36 |
+
โ โโโโโโโโโโโโโโโโโโโโฌโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
|
| 37 |
+
โ โ (้จๆง่ๅ) โ
|
| 38 |
+
โ โโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
|
| 39 |
+
โ โ 3. Positional Encoding โ โ
|
| 40 |
+
โ โ ๆญฃๅผฆไฝ็ฝฎ็ผ็ (ๆถ้ดไฝ็ฝฎๅฏน็ก็ ็ปๆๅพ้่ฆ) โ โ
|
| 41 |
+
โ โโโโโโโโโโโโโโโโโโโโฌโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
|
| 42 |
+
โ โ โ
|
| 43 |
+
โ โโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
|
| 44 |
+
โ โ 4. Dilated Temporal CNN (ๅ่wav2sleep) โ โ
|
| 45 |
+
โ โ 2 blocks ร [d=1,2,4,8,16,32], k=7 โ โ
|
| 46 |
+
โ โ ๆๅ้ โ 6ๅฐๆถ โ ๆ่ทๅฎๆด็ก็ ๅจๆ โ โ
|
| 47 |
+
โ โโโโโโโโโโโโโโโโโโโโฌโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
|
| 48 |
+
โ โ โ
|
| 49 |
+
โ โโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
|
| 50 |
+
โ โ 5. Classification Head โ โ
|
| 51 |
+
โ โ Linear(d_model โ d_model/2 โ n_classes) โ โ
|
| 52 |
+
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
|
| 53 |
+
โ โ
|
| 54 |
+
โ ่พๅบ: (batch, T, n_classes) โ ๆฏไธชepoch็ๅ็ฑปlogits โ
|
| 55 |
+
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
### ๅ
ณ้ฎๅๆฐ
|
| 59 |
+
|
| 60 |
+
1. **ๅ่ทฏๅพ้จๆง่ๅ**๏ผ็ฎๅMLPๆๅฝฑ + ่ทจ็นๅพTransformerๆณจๆๅ็้จๆง่ๅ๏ผๅ
ผ้กพๆ็ๅ่กจ่พพ่ฝๅ
|
| 61 |
+
2. **้ๆบ็นๅพ้ฎ่ฝ** (ๅ่wav2sleep)๏ผ่ฎญ็ปๆถไปฅ30%ๆฆ็้ๆบmask็นๅพ๏ผๆ้ซไผ ๆๅจ็ผบๅคฑๆถ็้ฒๆฃๆง
|
| 62 |
+
3. **่จ่ๆถๅบCNN** (ๅ่wav2sleep)๏ผๆๅ้่ฆ็โ6ๅฐๆถ๏ผ่ฝๆ่ทๅฎๆด็90-120ๅ้็ก็ ๅจๆ
|
| 63 |
+
4. **Per-patient Z-scoreๆ ๅๅ** (ๅ่SleepPPG-Net)๏ผๆถ้คไธชไฝๅบ็บฟๅทฎๅผ๏ผ่ฟๆฏๆๅ
ณ้ฎ็้ขๅค็ๆญฅ้ชค
|
| 64 |
+
|
| 65 |
+
## ๐ ๆจกๅ้
็ฝฎ
|
| 66 |
+
|
| 67 |
+
| ้
็ฝฎ | d_model | ๅๆฐ้ | Epoch Mixer | Sequence Mixer | ้็จๅบๆฏ |
|
| 68 |
+
|------|---------|--------|-------------|----------------|---------|
|
| 69 |
+
| small | 64 | ~195K | 1ๅฑ Transformer | 1 block, d=[1,2,4,8,16] | ๅฟซ้ๅฎ้ช |
|
| 70 |
+
| **base** | 128 | **~2M** | 2ๅฑ Transformer | 2 blocks, d=[1,2,4,8,16,32] | **ๆจ่** |
|
| 71 |
+
| large | 256 | ~13.7M | 3ๅฑ Transformer | 3 blocks, d=[1,2,4,8,16,32,64] | ๆไฝณๆง่ฝ |
|
| 72 |
+
|
| 73 |
+
## ๐ฌ ๅ่ๆ็ฎ
|
| 74 |
+
|
| 75 |
+
| ่ฎบๆ | ่ดก็ฎ | ๅนดไปฝ |
|
| 76 |
+
|------|------|------|
|
| 77 |
+
| [wav2sleep](https://arxiv.org/abs/2411.04644) | Epoch Mixer + Sequence Mixer + ้ๆบๆจกๆ้ฎ่ฝ | 2024 |
|
| 78 |
+
| [Cross-Modal Transformer](https://arxiv.org/abs/2208.06991) | ่ทจๆจกๆๆณจๆๅ + ๅ ๆไบคๅ็ต [1,2,1,2,2] | 2022 |
|
| 79 |
+
| [SleepPPG-Net](https://arxiv.org/abs/2202.05735) | Per-patient Z-score + BiLSTM FEๅบ็บฟ | 2022 |
|
| 80 |
+
| [Mamba-sleep](https://arxiv.org/abs/2412.15947) | ่ฝป้Mambaๅบๅๅปบๆจก + ้้ข็ๅ ๆ | 2024 |
|
| 81 |
+
|
| 82 |
+
## ๐ ๅฟซ้ไฝฟ็จ
|
| 83 |
+
|
| 84 |
+
### ๅฎ่ฃ
ไพ่ต
|
| 85 |
+
|
| 86 |
+
```bash
|
| 87 |
+
pip install torch numpy
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
### ๆจ็็คบไพ
|
| 91 |
+
|
| 92 |
+
```python
|
| 93 |
+
import numpy as np
|
| 94 |
+
import torch
|
| 95 |
+
from sleep_staging_model import create_model
|
| 96 |
+
|
| 97 |
+
# ๅๅปบๆจกๅ
|
| 98 |
+
model = create_model('base', n_features=4, n_classes=4)
|
| 99 |
+
|
| 100 |
+
# ๅๅค่พๅ
ฅๆฐๆฎ: Tไธช30็งepoch็็นๅพ
|
| 101 |
+
T = 1200 # 10ๅฐๆถ = 1200ไธชepoch
|
| 102 |
+
features = np.stack([
|
| 103 |
+
hrv_rmssd, # HRV (RMSSD) ๅบๅ
|
| 104 |
+
heart_rate, # ๅฟ็ๅบๅ
|
| 105 |
+
respiratory_rate, # ๅผๅธ้ข็ๅบๅ
|
| 106 |
+
body_movement, # ไฝๅจๅบๅ
|
| 107 |
+
], axis=-1) # shape: (T, 4)
|
| 108 |
+
|
| 109 |
+
# Z-scoreๆ ๅๅ (ๅ
ณ้ฎๆญฅ้ชค!)
|
| 110 |
+
features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8)
|
| 111 |
+
features = np.clip(features, -5, 5)
|
| 112 |
+
|
| 113 |
+
# ๆจ็
|
| 114 |
+
x = torch.tensor(features, dtype=torch.float32).unsqueeze(0) # (1, T, 4)
|
| 115 |
+
model.eval()
|
| 116 |
+
with torch.no_grad():
|
| 117 |
+
logits = model(x) # (1, T, n_classes)
|
| 118 |
+
predictions = torch.argmax(logits, dim=-1) # (1, T)
|
| 119 |
+
|
| 120 |
+
# ๆ ็ญพ: 0=Wake, 1=N1, 2=N2, 3=N3, 4=REM
|
| 121 |
+
stage_names = {0: 'Wake', 1: 'N1', 2: 'N2', 3: 'N3', 4: 'REM'}
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
### ่ฎญ็ป
|
| 125 |
+
|
| 126 |
+
```bash
|
| 127 |
+
python train_sleep_staging.py --model base --batch_size 16 --lr 1e-3 --max_epochs 100
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
## ๐ ่ฎญ็ป้
็ฝฎ (ๆฅ่ชๆ็ฎๆไฝณๅฎ่ทต)
|
| 131 |
+
|
| 132 |
+
| ่ถ
ๅๆฐ | ๅผ | ๆฅๆบ |
|
| 133 |
+
|--------|------|------|
|
| 134 |
+
| Optimizer | AdamW | wav2sleep |
|
| 135 |
+
| Learning Rate | 1e-3 | wav2sleep |
|
| 136 |
+
| Weight Decay | 1e-2 | wav2sleep |
|
| 137 |
+
| Batch Size | 16 (ๆดๅค) | wav2sleep |
|
| 138 |
+
| LR Schedule | OneCycleLR (10% warmup + cosine) | ๆน่ฟ่ชwav2sleep |
|
| 139 |
+
| Loss | Weighted Focal Loss (ฮณ=2) | Cross-Modal Transformer + Focal |
|
| 140 |
+
| Class Weights | ้้ข็ๅ ๆ | Mamba-sleep |
|
| 141 |
+
| Early Stopping | patience=10 | wav2sleep (5) |
|
| 142 |
+
| Gradient Clip | max_norm=1.0 | ๆ ๅๅฎ่ทต |
|
| 143 |
+
| Augmentation | ้ๆบ็ฟป่ฝฌ(p=0.5) + ๅชๅฃฐ(p=0.3) | wav2sleep |
|
| 144 |
+
| Feature Mask | p=0.3 | wav2sleep |
|
| 145 |
+
|
| 146 |
+
## ๐ ๆฐๆฎ้
|
| 147 |
+
|
| 148 |
+
่ฎญ็ปๆฐๆฎ: [`abmallick/heart-breath-sleep-stage-dataset`](https://huggingface.co/datasets/abmallick/heart-breath-sleep-stage-dataset)
|
| 149 |
+
|
| 150 |
+
- 30็งepoch็ฒๅบฆ
|
| 151 |
+
- ๅ
ๅซๅฟ็ใๅผๅธ้ข็ใHRVๆๆ
|
| 152 |
+
- 485+ๅค็ก็ ่ฎฐๅฝ
|
| 153 |
+
- 4็ฑปๆ ๆณจ: Wake, N1, N2, N3
|
| 154 |
+
|
| 155 |
+
## ๆๆๆง่ฝ (ๅบไบๆ็ฎ)
|
| 156 |
+
|
| 157 |
+
| ๆๆ | 4็ฑป (ๆ REM) | 5็ฑป (ๅซREM) |
|
| 158 |
+
|------|------------|------------|
|
| 159 |
+
| Cohen's ฮบ | 0.55-0.65 | 0.50-0.60 |
|
| 160 |
+
| Accuracy | 70-80% | 65-75% |
|
| 161 |
+
| F1 (macro) | 0.50-0.65 | 0.45-0.60 |
|
| 162 |
+
|
| 163 |
+
> ๆณจ: ้EEG็นๅพ็็ก็ ๅๆๆง่ฝ้ๅธธไฝไบEEG-basedๆนๆณ(ฮบโ0.75+)๏ผ่ฟๆฏ่ฏฅ้ขๅ็ๅบๆ้ๅถใ
|
| 164 |
+
|
| 165 |
+
## License
|
| 166 |
+
|
| 167 |
+
MIT
|