""" Sleep Staging Model - 基于 wav2sleep + Cross-Modal Transformer 的混合架构 参考文献: 1. wav2sleep (2411.04644) - 多模态睡眠分期SOTA 2. Cross-Modal Transformer (2208.06991) - 跨模态注意力机制 3. SleepPPG-Net (2202.05735) - 特征工程分支BiLSTM基线 4. Mamba-based Sleep Staging (2412.15947) - 轻量级序列建模 输入特征: HRV(神经状态), 心率(整体水平), 呼吸频率, 体动 输出: 4/5类睡眠分期 (Wake, N1, N2, N3, [REM]) 架构设计 (SleepStageNet): ┌─────────────────────────────────────────────────────┐ │ 1. Feature Projection Layer (per-epoch) │ │ Linear(n_features → d_model) + LayerNorm + GELU │ ├─────────────────────────────────────────────────────┤ │ 2. Cross-Feature Attention (Epoch Mixer) │ │ Transformer Encoder with CLS token │ │ - 融合HRV/HR/RR/Movement的交互关系 │ │ - 参考wav2sleep的Epoch Mixer设计 │ ├─────────────────────────────────────────────────────┤ │ 3. Temporal Context (Sequence Mixer) │ │ Dilated Temporal CNN │ │ - 捕获睡眠周期的长程时序依赖 │ │ - dilations=[1,2,4,8,16,32], kernel=7 │ │ - 参考wav2sleep的Sequence Mixer │ ├─────────────────────────────────────────────────────┤ │ 4. Classification Head │ │ Linear(d_model → n_classes) + Softmax │ └─────────────────────────────────────────────────────┘ """ import torch import torch.nn as nn import torch.nn.functional as F import math from typing import Optional, Tuple class FeatureProjection(nn.Module): """将低维输入特征投影到模型隐藏维度 (参考SleepPPG-Net FE branch)""" def __init__(self, n_features: int = 4, d_model: int = 128, dropout: float = 0.1): super().__init__() self.projection = nn.Sequential( nn.Linear(n_features, d_model * 2), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model * 2, d_model), nn.LayerNorm(d_model), nn.GELU(), nn.Dropout(dropout), ) def forward(self, x): return self.projection(x) class EfficientCrossFeatureAttention(nn.Module): """ 高效跨特征注意力 (Epoch Mixer) 参考 wav2sleep Epoch Mixer + Cross-Modal Transformer 将每个特征视为独立模态, 用Transformer + CLS token融合 """ def __init__(self, n_features=4, d_model=128, nhead=4, num_layers=2, dim_feedforward=512, dropout=0.1): super().__init__() self.n_features = n_features self.d_model = 
d_model self.feature_embeddings = nn.ModuleList([ nn.Sequential(nn.Linear(1, d_model), nn.GELU(), nn.LayerNorm(d_model)) for _ in range(n_features) ]) self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02) self.feature_type_embedding = nn.Parameter(torch.randn(1, n_features + 1, d_model) * 0.02) encoder_layer = nn.TransformerEncoderLayer( d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation='gelu', batch_first=True, norm_first=True, ) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, norm=nn.LayerNorm(d_model)) def forward(self, features): B, T, F = features.shape flat = features.reshape(B * T, F) embedded = torch.cat([self.feature_embeddings[i](flat[:, i:i+1]).unsqueeze(1) for i in range(self.n_features)], dim=1) cls = self.cls_token.expand(B * T, -1, -1) tokens = torch.cat([cls, embedded], dim=1) + self.feature_type_embedding encoded = self.transformer(tokens) return encoded[:, 0, :].reshape(B, T, self.d_model) class DilatedResidualBlock(nn.Module): """膨胀残差卷积块 (参考wav2sleep Sequence Mixer)""" def __init__(self, d_model, kernel_size=7, dilation=1, dropout=0.1): super().__init__() padding = (kernel_size - 1) * dilation // 2 self.conv = nn.Sequential( nn.Conv1d(d_model, d_model, kernel_size, padding=padding, dilation=dilation), nn.GELU(), nn.Dropout(dropout), nn.Conv1d(d_model, d_model, 1), nn.GELU(), nn.Dropout(dropout), ) self.norm = nn.LayerNorm(d_model) def forward(self, x): residual = x out = self.conv(x.transpose(1, 2)).transpose(1, 2) if out.size(1) != residual.size(1): out = out[:, :residual.size(1), :] return self.norm(out + residual) class DilatedTemporalCNN(nn.Module): """膨胀时序CNN (参考wav2sleep Sequence Mixer, 感受野≈6小时)""" def __init__(self, d_model=128, kernel_size=7, dilations=None, n_blocks=2, dropout=0.1): super().__init__() if dilations is None: dilations = [1, 2, 4, 8, 16, 32] self.layers = nn.ModuleList([ DilatedResidualBlock(d_model, kernel_size, d, dropout) for _ in 
range(n_blocks) for d in dilations ]) def forward(self, x): for layer in self.layers: x = layer(x) return x class SinusoidalPositionalEncoding(nn.Module): def __init__(self, d_model, max_len=2000, dropout=0.1): super().__init__() self.dropout = nn.Dropout(p=dropout) pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) self.register_buffer('pe', pe.unsqueeze(0)) def forward(self, x): return self.dropout(x + self.pe[:, :x.size(1), :]) class SleepStageNet(nn.Module): """ 睡眠分期模型 - 综合wav2sleep + Cross-Modal Transformer的最佳设计 输入: (batch, seq_len, 4) - [HRV, HR, RR, Movement] per 30-sec epoch 输出: (batch, seq_len, n_classes) - 每个epoch的睡眠分期logits """ STAGE_NAMES = {0: 'Wake', 1: 'N1', 2: 'N2', 3: 'N3', 4: 'REM'} def __init__(self, n_features=4, n_classes=5, d_model=128, nhead=4, epoch_mixer_layers=2, dim_feedforward=512, seq_mixer_blocks=2, seq_mixer_kernel=7, seq_mixer_dilations=None, max_seq_len=1500, dropout=0.1, feature_mask_prob=0.3, use_efficient_attention=True): super().__init__() self.n_features, self.n_classes, self.d_model = n_features, n_classes, d_model self.feature_mask_prob = feature_mask_prob if seq_mixer_dilations is None: seq_mixer_dilations = [1, 2, 4, 8, 16, 32] self.simple_projection = FeatureProjection(n_features, d_model, dropout) self.cross_feature_attn = EfficientCrossFeatureAttention( n_features, d_model, nhead, epoch_mixer_layers, dim_feedforward, dropout) self.fusion_gate = nn.Sequential(nn.Linear(d_model * 2, d_model), nn.Sigmoid()) self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout) self.seq_mixer = DilatedTemporalCNN(d_model, seq_mixer_kernel, seq_mixer_dilations, seq_mixer_blocks, dropout) self.classifier = nn.Sequential( nn.Linear(d_model, d_model // 2), nn.GELU(), nn.Dropout(dropout), 
nn.Linear(d_model // 2, n_classes)) self._init_weights() def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Conv1d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.zeros_(m.bias) def _stochastic_feature_mask(self, x): if self.training and self.feature_mask_prob > 0: mask = torch.bernoulli(torch.ones(x.shape[0], 1, x.shape[2], device=x.device) * (1 - self.feature_mask_prob)) while (mask.sum(dim=2) == 0).any(): mask = torch.bernoulli(torch.ones(x.shape[0], 1, x.shape[2], device=x.device) * (1 - self.feature_mask_prob)) x = x * mask return x def forward(self, x, mask=None): x = self._stochastic_feature_mask(x) proj = self.simple_projection(x) attn = self.cross_feature_attn(x) gate = self.fusion_gate(torch.cat([proj, attn], dim=-1)) fused = gate * proj + (1 - gate) * attn fused = self.pos_encoding(fused) temporal = self.seq_mixer(fused) return self.classifier(temporal) def predict(self, x): self.eval() with torch.no_grad(): return torch.argmax(self.forward(x), dim=-1) def count_parameters(self): return sum(p.numel() for p in self.parameters() if p.requires_grad) class WeightedFocalLoss(nn.Module): """加权Focal Loss (参考Cross-Modal Transformer + Mamba-sleep)""" def __init__(self, class_weights=None, gamma=2.0, reduction='mean'): super().__init__() if class_weights is None: class_weights = [1.0, 2.0, 1.0, 1.5, 1.5] self.register_buffer('weight', torch.tensor(class_weights, dtype=torch.float32)) self.gamma, self.reduction = gamma, reduction def forward(self, logits, targets): if logits.dim() == 3: logits, targets = logits.reshape(-1, logits.size(-1)), targets.reshape(-1) ce = F.cross_entropy(logits, targets, weight=self.weight, reduction='none') focal = (1 - torch.exp(-ce)) ** self.gamma * ce return focal.mean() if self.reduction == 'mean' else focal.sum() if self.reduction == 'sum' else focal 
class SleepDataProcessor:
    """Utility transforms applied to sleep-staging feature matrices."""

    @staticmethod
    def per_patient_normalize(features, night_ids):
        """Z-score every feature column independently within each night.

        Args:
            features: (n_epochs, n_features) numpy array.
            night_ids: (n_epochs,) array assigning each epoch to a recording.

        Returns:
            Array of the same shape, normalized per night. Near-constant
            columns (std < 1e-8) are only mean-centered, never divided by ~0.
        """
        import numpy as np
        out = features.copy()
        for night in np.unique(night_ids):
            rows = night_ids == night
            chunk = features[rows]
            mu = chunk.mean(axis=0)
            sigma = chunk.std(axis=0)
            sigma[sigma < 1e-8] = 1.0
            out[rows] = (chunk - mu) / sigma
        return out


# Named presets scaling model capacity; values feed SleepStageNet kwargs.
MODEL_CONFIGS = {
    'small': dict(d_model=64, nhead=4, epoch_mixer_layers=1, dim_feedforward=256,
                  seq_mixer_blocks=1, seq_mixer_kernel=5,
                  seq_mixer_dilations=[1, 2, 4, 8, 16], dropout=0.1),
    'base': dict(d_model=128, nhead=4, epoch_mixer_layers=2, dim_feedforward=512,
                 seq_mixer_blocks=2, seq_mixer_kernel=7,
                 seq_mixer_dilations=[1, 2, 4, 8, 16, 32], dropout=0.1),
    'large': dict(d_model=256, nhead=8, epoch_mixer_layers=3, dim_feedforward=1024,
                  seq_mixer_blocks=3, seq_mixer_kernel=7,
                  seq_mixer_dilations=[1, 2, 4, 8, 16, 32, 64], dropout=0.15),
}


def create_model(config_name='base', n_features=4, n_classes=5, **kwargs):
    """Build a SleepStageNet from a named preset, with keyword overrides.

    Raises KeyError if config_name is not one of MODEL_CONFIGS.
    """
    params = dict(MODEL_CONFIGS[config_name], **kwargs)
    model = SleepStageNet(n_features=n_features, n_classes=n_classes, **params)
    print(f"Created SleepStageNet-{config_name} ({model.count_parameters():,} params)")
    return model