# SleepStageNet - Sleep Staging Model Based on Heart Rate, HRV, Respiration, and Body Movement

## 🎯 Overview

SleepStageNet is a deep-learning sleep staging model that uses **four non-EEG features** for automatic sleep staging:

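| Feature | Description | Physiological meaning |
|---------|-------------|-----------------------|
| **HRV (RMSSD)** | Heart rate variability | Reflects autonomic nervous state (parasympathetic activity) |
| **Heart rate (HR)** | Heartbeats per minute | Reflects overall cardiovascular level |
| **Respiratory rate (RR)** | Breaths per minute | Reflects respiratory regulation |
| **Body movement (Movement)** | Amount of body activity | Reflects limb movement / arousal |
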
**Output**: a sleep stage for each 30-second epoch → Wake / N1 / N2 / N3 / REM

## Architecture

The design combines the best ideas from the following state-of-the-art papers:

```
┌────────────────────────────────────────────────────────────────┐
│                         SleepStageNet                          │
├────────────────────────────────────────────────────────────────┤
│                                                                │
│  Input: (batch, T, 4); T 30-second epochs, 4 features each     │
│                                                                │
│  ┌────────────────────────────────────────────────────────┐    │
│  │ 1. Feature Projection (following SleepPPG-Net)         │    │
│  │    MLP: 4 → d_model*2 → d_model                        │    │
│  └────────────────────┬───────────────────────────────────┘    │
│                       │                                        │
│  ┌────────────────────┴───────────────────────────────────┐    │
│  │ 2. Cross-Feature Attention (following wav2sleep)       │    │
│  │    Per-feature projection + CLS token + Transformer    │    │
│  │    Learns HRV↔HR↔RR↔Movement interactions              │    │
│  └────────────────────┬───────────────────────────────────┘    │
│                       │ (gated fusion)                         │
│  ┌────────────────────┴───────────────────────────────────┐    │
│  │ 3. Positional Encoding                                 │    │
│  │    Sinusoidal positional encoding (temporal position   │    │
│  │    matters greatly for sleep structure)                │    │
│  └────────────────────┬───────────────────────────────────┘    │
│                       │                                        │
│  ┌────────────────────┴───────────────────────────────────┐    │
│  │ 4. Dilated Temporal CNN (following wav2sleep)          │    │
│  │    2 blocks × [d=1,2,4,8,16,32], k=7                   │    │
│  │    Receptive field ≈ 6 hours → full sleep cycles       │    │
│  └────────────────────┬───────────────────────────────────┘    │
│                       │                                        │
│  ┌────────────────────┴───────────────────────────────────┐    │
│  │ 5. Classification Head                                 │    │
│  │    Linear(d_model → d_model/2 → n_classes)             │    │
│  └────────────────────────────────────────────────────────┘    │
│                                                                │
│  Output: (batch, T, n_classes); per-epoch class logits         │
└────────────────────────────────────────────────────────────────┘
```

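The "≈ 6 hours" receptive field in step 4 can be checked with a little arithmetic. The sketch below is not taken from the model code: the helper name `receptive_field` is hypothetical, and it assumes one stride-1 convolution per dilation value in each block, so each layer widens the receptive field by `(k - 1) * d` epochs.

```python
def receptive_field(n_blocks, dilations, kernel_size):
    """Receptive field (in epochs) of stacked dilated 1-D convolutions.

    Assumption: each block holds one stride-1 convolution per dilation value.
    """
    rf = 1  # a single epoch sees itself
    for _ in range(n_blocks):
        for d in dilations:
            rf += (kernel_size - 1) * d  # each layer adds (k - 1) * d epochs
    return rf


EPOCH_SECONDS = 30

# "base" configuration from the diagram: 2 blocks, d=[1,2,4,8,16,32], k=7
rf_epochs = receptive_field(n_blocks=2, dilations=[1, 2, 4, 8, 16, 32], kernel_size=7)
rf_hours = rf_epochs * EPOCH_SECONDS / 3600
print(rf_epochs, round(rf_hours, 2))  # 757 6.31
```

Under that assumption the base configuration sees 757 epochs, about 6.3 hours, which agrees with the figure quoted in the diagram.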
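### Key Innovations

1. **Dual-path gated fusion**: a simple MLP projection and cross-feature Transformer attention are combined through a gate, balancing efficiency and expressive power
2. **Random feature masking** (following wav2sleep): during training, features are randomly masked with 30% probability, improving robustness when a sensor is missing
3. **Dilated temporal CNN** (following wav2sleep): the receptive field covers ≈ 6 hours, enough to capture complete 90-120 minute sleep cycles
4. **Per-patient Z-score normalization** (following SleepPPG-Net): removes individual baseline differences; this is the most critical preprocessing step
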
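Random feature masking (item 2) could look roughly like the sketch below. The function name `mask_features` is hypothetical, and the details are assumptions rather than the repository's implementation: each feature channel is dropped independently with probability `p`, a dropped channel is set to 0.0 for the whole recording (the mean after Z-scoring), and at least one channel is always kept.

```python
import random


def mask_features(sequence, p=0.3, rng=random):
    """Zero out whole feature channels of one recording with probability p.

    sequence: list of T epochs, each a list of n_features Z-scored values.
    Assumptions: a masked channel becomes 0.0 for every epoch, and at
    least one channel always survives.
    """
    n_features = len(sequence[0])
    keep = [rng.random() >= p for _ in range(n_features)]
    if not any(keep):
        # Never hand the model an all-zero input.
        keep[rng.randrange(n_features)] = True
    masked = [[v if k else 0.0 for v, k in zip(epoch, keep)] for epoch in sequence]
    return masked, keep


rng = random.Random(0)  # seeded only to make the toy example repeatable
recording = [[0.5, -1.2, 0.3, 2.0] for _ in range(6)]  # toy (T=6, 4) input
masked, keep = mask_features(recording, p=0.3, rng=rng)
```

Masking is meant for training only; at inference time the features are passed through unchanged.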
## Model Configurations

| Config | d_model | Parameters | Epoch Mixer | Sequence Mixer | Use case |
|--------|---------|------------|-------------|----------------|----------|
| small | 64 | ~195K | 1-layer Transformer | 1 block, d=[1,2,4,8,16] | Quick experiments |
| **base** | 128 | **~2M** | 2-layer Transformer | 2 blocks, d=[1,2,4,8,16,32] | **Recommended** |
| large | 256 | ~13.7M | 3-layer Transformer | 3 blocks, d=[1,2,4,8,16,32,64] | Best performance |

## 🔬 References

| Paper | Contribution | Year |
|-------|--------------|------|
| [wav2sleep](https://arxiv.org/abs/2411.04644) | Epoch Mixer + Sequence Mixer + random modality masking | 2024 |
| [Cross-Modal Transformer](https://arxiv.org/abs/2208.06991) | Cross-modal attention + weighted cross-entropy [1,2,1,2,2] | 2022 |
| [SleepPPG-Net](https://arxiv.org/abs/2202.05735) | Per-patient Z-score + BiLSTM FE baseline | 2022 |
| [Mamba-sleep](https://arxiv.org/abs/2412.15947) | Lightweight Mamba sequence modeling + inverse-frequency weighting | 2024 |

## Quick Start

### Install Dependencies

```bash
pip install torch numpy
```

### Inference Example

```python
import numpy as np
import torch
from sleep_staging_model import create_model

# Create the model (n_classes=4 matches the 4-class labels: Wake, N1, N2, N3)
model = create_model('base', n_features=4, n_classes=4)

# Prepare the input: features for T 30-second epochs.
# hrv_rmssd, heart_rate, respiratory_rate and body_movement are
# 1-D arrays of length T that you supply.
T = 1200  # 10 hours = 1200 epochs
features = np.stack([
    hrv_rmssd,         # HRV (RMSSD) series
    heart_rate,        # heart rate series
    respiratory_rate,  # respiratory rate series
    body_movement,     # body movement series
], axis=-1)  # shape: (T, 4)

# Z-score normalization over this recording (critical step!)
features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8)
features = np.clip(features, -5, 5)

# Inference
x = torch.tensor(features, dtype=torch.float32).unsqueeze(0)  # (1, T, 4)
model.eval()
with torch.no_grad():
    logits = model(x)  # (1, T, n_classes)
    predictions = torch.argmax(logits, dim=-1)  # (1, T)

# Labels: 0=Wake, 1=N1, 2=N2, 3=N3, 4=REM (index 4 only applies when n_classes=5)
stage_names = {0: 'Wake', 1: 'N1', 2: 'N2', 3: 'N3', 4: 'REM'}
```

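Once per-epoch predictions are available, a common next step is to summarize how long the night spent in each stage. The helper below is a minimal sketch and not part of the repository; `minutes_per_stage` is a hypothetical name, and the hard-coded list stands in for `predictions[0].tolist()` from the inference example.

```python
from collections import Counter

EPOCH_SECONDS = 30
STAGE_NAMES = {0: "Wake", 1: "N1", 2: "N2", 3: "N3", 4: "REM"}


def minutes_per_stage(pred_indices):
    """Total minutes spent in each predicted stage (30-second epochs)."""
    counts = Counter(pred_indices)
    return {STAGE_NAMES[i]: n * EPOCH_SECONDS / 60 for i, n in sorted(counts.items())}


preds = [0, 0, 1, 2, 2, 2, 3, 3, 2, 0]  # stand-in for predictions[0].tolist()
summary = minutes_per_stage(preds)
print(summary)  # {'Wake': 1.5, 'N1': 0.5, 'N2': 2.0, 'N3': 1.0}
```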
### Training

```bash
python train_sleep_staging.py --model base --batch_size 16 --lr 1e-3 --max_epochs 100
```

## Training Configuration (Best Practices from the Literature)

| Hyperparameter | Value | Source |
|----------------|-------|--------|
| Optimizer | AdamW | wav2sleep |
| Learning Rate | 1e-3 | wav2sleep |
| Weight Decay | 1e-2 | wav2sleep |
| Batch Size | 16 (whole nights) | wav2sleep |
| LR Schedule | OneCycleLR (10% warmup + cosine) | Adapted from wav2sleep |
| Loss | Weighted Focal Loss (γ=2) | Cross-Modal Transformer + Focal |
| Class Weights | Inverse-frequency weighting | Mamba-sleep |
| Early Stopping | patience=10 | wav2sleep (5) |
| Gradient Clip | max_norm=1.0 | Standard practice |
| Augmentation | Random flip (p=0.5) + noise (p=0.3) | wav2sleep |
| Feature Mask | p=0.3 | wav2sleep |

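To make the loss and class-weight rows concrete, here is a minimal plain-Python sketch of a weighted focal loss for a single epoch, with weights derived from inverse class frequency. It illustrates the formulas only and is not the training script's implementation: the function names are hypothetical, and normalizing the weights to a mean of 1 is an assumption.

```python
import math
from collections import Counter


def inverse_frequency_weights(labels, n_classes):
    """Class weights proportional to 1/frequency, scaled to a mean of 1 (assumption)."""
    counts = Counter(labels)
    # max(..., 1) avoids division by zero for classes absent from the labels.
    raw = [1.0 / max(counts.get(c, 0), 1) for c in range(n_classes)]
    scale = n_classes / sum(raw)
    return [w * scale for w in raw]


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_focal_loss(logits, target, class_weights, gamma=2.0):
    """Focal loss for one epoch: -w_t * (1 - p_t)**gamma * log(p_t)."""
    p_t = softmax(logits)[target]
    return -class_weights[target] * (1.0 - p_t) ** gamma * math.log(p_t)


labels = [2] * 6 + [0] * 2 + [1] + [3]  # toy, N2-heavy label distribution
weights = inverse_frequency_weights(labels, n_classes=4)
logits = [0.1, 0.2, 4.0, 0.1]  # model is confident in class 2
easy = weighted_focal_loss(logits, target=2, class_weights=weights)
hard = weighted_focal_loss(logits, target=1, class_weights=weights)
```

The `(1 - p_t)**gamma` factor shrinks the loss on epochs the model already classifies confidently (`easy`), so rare and misclassified stages (`hard`) dominate the gradient. With `gamma=0` the expression reduces to weighted cross-entropy.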
## Dataset

Training data: [`abmallick/heart-breath-sleep-stage-dataset`](https://huggingface.co/datasets/abmallick/heart-breath-sleep-stage-dataset)

- 30-second epoch granularity
- Includes heart rate, respiratory rate, and HRV metrics
- 485+ nights of sleep recordings
- 4-class labels: Wake, N1, N2, N3

## Expected Performance (Based on the Literature)

| Metric | 4-class (no REM) | 5-class (with REM) |
|--------|------------------|--------------------|
| Cohen's κ | 0.55-0.65 | 0.50-0.60 |
| Accuracy | 70-80% | 65-75% |
| F1 (macro) | 0.50-0.65 | 0.45-0.60 |

> Note: Sleep staging from non-EEG features typically performs below EEG-based methods (κ≈0.75+); this is an inherent limitation of the field.

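Cohen's κ corrects raw accuracy for the agreement expected by chance, which matters here because stage distributions are heavily imbalanced. The sketch below computes it from per-epoch labels. It is an illustrative helper, not the project's evaluation code; the name `cohens_kappa` and the toy labels are made up for the example.

```python
from collections import Counter


def cohens_kappa(y_true, y_pred):
    """Cohen's kappa: (p_observed - p_expected) / (1 - p_expected)."""
    n = len(y_true)
    observed = sum(t == p for t, p in zip(y_true, y_pred)) / n
    true_counts = Counter(y_true)
    pred_counts = Counter(y_pred)
    # Chance agreement: product of the marginal frequencies, summed over classes.
    expected = sum(true_counts[c] * pred_counts[c] for c in true_counts) / (n * n)
    return (observed - expected) / (1.0 - expected)


y_true = [0, 0, 1, 1, 2, 2, 2, 3]  # toy reference labels
y_pred = [0, 1, 1, 1, 2, 2, 3, 3]  # toy predictions
kappa = cohens_kappa(y_true, y_pred)
print(round(kappa, 3))  # 0.667
```

In this toy case accuracy is 0.75 and chance agreement is 0.25, giving κ = 0.5 / 0.75 ≈ 0.667. The helper does not guard against the degenerate case where chance agreement equals 1.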
## License

MIT