# SleepStageNet - 基于心率/HRV/呼吸/体动的睡眠分期模型 ## 🎯 概述 SleepStageNet 是一个基于深度学习的睡眠分期模型,使用 **4个非EEG特征** 进行自动睡眠分期: | 特征 | 说明 | 生理意义 | |------|------|---------| | **HRV (RMSSD)** | 心率变异性 | 反映自主神经状态(副交感活性) | | **心率 (HR)** | 每分钟心跳次数 | 反映整体心血管水平 | | **呼吸频率 (RR)** | 每分钟呼吸次数 | 反映呼吸调节状态 | | **体动 (Movement)** | 身体活动量 | 反映肢体运动/觉醒 | **输出**:每30秒epoch的睡眠分期 → Wake / N1 / N2 / N3 / REM ## 📐 架构设计 综合了以下SOTA论文的最佳设计: ``` ┌──────────────────────────────────────────────────────────────┐ │ SleepStageNet │ ├──────────────────────────────────────────────────────────────┤ │ │ │ 输入: (batch, T, 4) ← T个30秒epoch, 每个有4个特征 │ │ │ │ ┌──────────────────────────────────────────────────┐ │ │ │ 1. Feature Projection (参考SleepPPG-Net) │ │ │ │ MLP: 4 → d_model*2 → d_model │ │ │ └──────────────────┬───────────────────────────────┘ │ │ │ │ │ ┌──────────────────┴───────────────────────────────┐ │ │ │ 2. Cross-Feature Attention (参考wav2sleep) │ │ │ │ 每个特征独立投影 + CLS Token + Transformer │ │ │ │ 学习 HRV↔HR↔RR↔Movement 的交互关系 │ │ │ └──────────────────┬───────────────────────────────┘ │ │ │ (门控融合) │ │ ┌──────────────────┴───────────────────────────────┐ │ │ │ 3. Positional Encoding │ │ │ │ 正弦位置编码 (时间位置对睡眠结构很重要) │ │ │ └──────────────────┬───────────────────────────────┘ │ │ │ │ │ ┌──────────────────┴───────────────────────────────┐ │ │ │ 4. Dilated Temporal CNN (参考wav2sleep) │ │ │ │ 2 blocks × [d=1,2,4,8,16,32], k=7 │ │ │ │ 感受野 ≈ 6小时 → 捕获完整睡眠周期 │ │ │ └──────────────────┬───────────────────────────────┘ │ │ │ │ │ ┌──────────────────┴───────────────────────────────┐ │ │ │ 5. Classification Head │ │ │ │ Linear(d_model → d_model/2 → n_classes) │ │ │ └──────────────────────────────────────────────────┘ │ │ │ │ 输出: (batch, T, n_classes) ← 每个epoch的分类logits │ └──────────────────────────────────────────────────────────────┘ ``` ### 关键创新 1. **双路径门控融合**:简单MLP投影 + 跨特征Transformer注意力的门控融合,兼顾效率和表达能力 2. **随机特征遮蔽** (参考wav2sleep):训练时以30%概率随机mask特征,提高传感器缺失时的鲁棒性 3. **膨胀时序CNN** (参考wav2sleep):感受野覆盖≈6小时,能捕获完整的90-120分钟睡眠周期 4. **Per-patient Z-score标准化** (参考SleepPPG-Net):消除个体基线差异,这是最关键的预处理步骤 ## 📊 模型配置 | 配置 | d_model | 参数量 | Epoch Mixer | Sequence Mixer | 适用场景 | |------|---------|--------|-------------|----------------|---------| | small | 64 | ~195K | 1层 Transformer | 1 block, d=[1,2,4,8,16] | 快速实验 | | **base** | 128 | **~2M** | 2层 Transformer | 2 blocks, d=[1,2,4,8,16,32] | **推荐** | | large | 256 | ~13.7M | 3层 Transformer | 3 blocks, d=[1,2,4,8,16,32,64] | 最佳性能 | ## 🔬 参考文献 | 论文 | 贡献 | 年份 | |------|------|------| | [wav2sleep](https://arxiv.org/abs/2411.04644) | Epoch Mixer + Sequence Mixer + 随机模态遮蔽 | 2024 | | [Cross-Modal Transformer](https://arxiv.org/abs/2208.06991) | 跨模态注意力 + 加权交叉熵 [1,2,1,2,2] | 2022 | | [SleepPPG-Net](https://arxiv.org/abs/2202.05735) | Per-patient Z-score + BiLSTM FE基线 | 2022 | | [Mamba-sleep](https://arxiv.org/abs/2412.15947) | 轻量Mamba序列建模 + 逆频率加权 | 2024 | ## 🚀 快速使用 ### 安装依赖 ```bash pip install torch numpy ``` ### 推理示例 ```python import numpy as np import torch from sleep_staging_model import create_model # 创建模型 model = create_model('base', n_features=4, n_classes=4) # 准备输入数据: T个30秒epoch的特征 T = 1200 # 10小时 = 1200个epoch features = np.stack([ hrv_rmssd, # HRV (RMSSD) 序列 heart_rate, # 心率序列 respiratory_rate, # 呼吸频率序列 body_movement, # 体动序列 ], axis=-1) # shape: (T, 4) # Z-score标准化 (关键步骤!) features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8) features = np.clip(features, -5, 5) # 推理 x = torch.tensor(features, dtype=torch.float32).unsqueeze(0) # (1, T, 4) model.eval() with torch.no_grad(): logits = model(x) # (1, T, n_classes) predictions = torch.argmax(logits, dim=-1) # (1, T) # 标签: 0=Wake, 1=N1, 2=N2, 3=N3, 4=REM stage_names = {0: 'Wake', 1: 'N1', 2: 'N2', 3: 'N3', 4: 'REM'} ``` ### 训练 ```bash python train_sleep_staging.py --model base --batch_size 16 --lr 1e-3 --max_epochs 100 ``` ## 📋 训练配置 (来自文献最佳实践) | 超参数 | 值 | 来源 | |--------|------|------| | Optimizer | AdamW | wav2sleep | | Learning Rate | 1e-3 | wav2sleep | | Weight Decay | 1e-2 | wav2sleep | | Batch Size | 16 (整夜) | wav2sleep | | LR Schedule | OneCycleLR (10% warmup + cosine) | 改进自wav2sleep | | Loss | Weighted Focal Loss (γ=2) | Cross-Modal Transformer + Focal | | Class Weights | 逆频率加权 | Mamba-sleep | | Early Stopping | patience=10 | wav2sleep (5) | | Gradient Clip | max_norm=1.0 | 标准实践 | | Augmentation | 随机翻转(p=0.5) + 噪声(p=0.3) | wav2sleep | | Feature Mask | p=0.3 | wav2sleep | ## 📊 数据集 训练数据: [`abmallick/heart-breath-sleep-stage-dataset`](https://huggingface.co/datasets/abmallick/heart-breath-sleep-stage-dataset) - 30秒epoch粒度 - 包含心率、呼吸频率、HRV指标 - 485+夜睡眠记录 - 4类标注: Wake, N1, N2, N3 ## 期望性能 (基于文献) | 指标 | 4类 (无REM) | 5类 (含REM) | |------|------------|------------| | Cohen's κ | 0.55-0.65 | 0.50-0.60 | | Accuracy | 70-80% | 65-75% | | F1 (macro) | 0.50-0.65 | 0.45-0.60 | > 注: 非EEG特征的睡眠分期性能通常低于EEG-based方法(κ≈0.75+),这是该领域的固有限制。 ## License MIT