Liuciba commited on
Commit
0ab03e7
·
verified ·
1 Parent(s): e2fb832

Add SleepStageNet model architecture

Browse files
Files changed (1) hide show
  1. sleep_staging_model.py +246 -0
sleep_staging_model.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sleep Staging Model - 基于 wav2sleep + Cross-Modal Transformer 的混合架构
3
+
4
+ 参考文献:
5
+ 1. wav2sleep (2411.04644) - 多模态睡眠分期SOTA
6
+ 2. Cross-Modal Transformer (2208.06991) - 跨模态注意力机制
7
+ 3. SleepPPG-Net (2202.05735) - 特征工程分支BiLSTM基线
8
+ 4. Mamba-based Sleep Staging (2412.15947) - 轻量级序列建模
9
+
10
+ 输入特征: HRV(神经状态), 心率(整体水平), 呼吸频率, 体动
11
+ 输出: 4/5类睡眠分期 (Wake, N1, N2, N3, [REM])
12
+
13
+ 架构设计 (SleepStageNet):
14
+ ┌─────────────────────────────────────────────────────┐
15
+ │ 1. Feature Projection Layer (per-epoch) │
16
+ │ Linear(n_features → d_model) + LayerNorm + GELU │
17
+ ├─────────────────────────────────────────────────────┤
18
+ │ 2. Cross-Feature Attention (Epoch Mixer) │
19
+ │ Transformer Encoder with CLS token │
20
+ │ - 融合HRV/HR/RR/Movement的交互关系 │
21
+ │ - 参考wav2sleep的Epoch Mixer设计 │
22
+ ├─────────────────────────────────────────────────────┤
23
+ │ 3. Temporal Context (Sequence Mixer) │
24
+ │ Dilated Temporal CNN │
25
+ │ - 捕获睡眠周期的长程时序依赖 │
26
+ │ - dilations=[1,2,4,8,16,32], kernel=7 │
27
+ │ - 参考wav2sleep的Sequence Mixer │
28
+ ├─────────────────────────────────────────────────────┤
29
+ │ 4. Classification Head │
30
+ │ Linear(d_model → n_classes) + Softmax │
31
+ └─────────────────────────────────────────────────────┘
32
+ """
33
+
34
+ import torch
35
+ import torch.nn as nn
36
+ import torch.nn.functional as F
37
+ import math
38
+ from typing import Optional, Tuple
39
+
40
+
41
+ class FeatureProjection(nn.Module):
42
+ """将低维输入特征投影到模型隐藏维度 (参考SleepPPG-Net FE branch)"""
43
+ def __init__(self, n_features: int = 4, d_model: int = 128, dropout: float = 0.1):
44
+ super().__init__()
45
+ self.projection = nn.Sequential(
46
+ nn.Linear(n_features, d_model * 2), nn.GELU(), nn.Dropout(dropout),
47
+ nn.Linear(d_model * 2, d_model), nn.LayerNorm(d_model), nn.GELU(), nn.Dropout(dropout),
48
+ )
49
+ def forward(self, x):
50
+ return self.projection(x)
51
+
52
+
53
+ class EfficientCrossFeatureAttention(nn.Module):
54
+ """
55
+ 高效跨特征注意力 (Epoch Mixer)
56
+ 参考 wav2sleep Epoch Mixer + Cross-Modal Transformer
57
+ 将每个特征视为独立模态, 用Transformer + CLS token融合
58
+ """
59
+ def __init__(self, n_features=4, d_model=128, nhead=4, num_layers=2, dim_feedforward=512, dropout=0.1):
60
+ super().__init__()
61
+ self.n_features = n_features
62
+ self.d_model = d_model
63
+ self.feature_embeddings = nn.ModuleList([
64
+ nn.Sequential(nn.Linear(1, d_model), nn.GELU(), nn.LayerNorm(d_model))
65
+ for _ in range(n_features)
66
+ ])
67
+ self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
68
+ self.feature_type_embedding = nn.Parameter(torch.randn(1, n_features + 1, d_model) * 0.02)
69
+ encoder_layer = nn.TransformerEncoderLayer(
70
+ d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
71
+ dropout=dropout, activation='gelu', batch_first=True, norm_first=True,
72
+ )
73
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, norm=nn.LayerNorm(d_model))
74
+
75
+ def forward(self, features):
76
+ B, T, F = features.shape
77
+ flat = features.reshape(B * T, F)
78
+ embedded = torch.cat([self.feature_embeddings[i](flat[:, i:i+1]).unsqueeze(1) for i in range(self.n_features)], dim=1)
79
+ cls = self.cls_token.expand(B * T, -1, -1)
80
+ tokens = torch.cat([cls, embedded], dim=1) + self.feature_type_embedding
81
+ encoded = self.transformer(tokens)
82
+ return encoded[:, 0, :].reshape(B, T, self.d_model)
83
+
84
+
85
+ class DilatedResidualBlock(nn.Module):
86
+ """膨胀残差卷积块 (参考wav2sleep Sequence Mixer)"""
87
+ def __init__(self, d_model, kernel_size=7, dilation=1, dropout=0.1):
88
+ super().__init__()
89
+ padding = (kernel_size - 1) * dilation // 2
90
+ self.conv = nn.Sequential(
91
+ nn.Conv1d(d_model, d_model, kernel_size, padding=padding, dilation=dilation),
92
+ nn.GELU(), nn.Dropout(dropout),
93
+ nn.Conv1d(d_model, d_model, 1), nn.GELU(), nn.Dropout(dropout),
94
+ )
95
+ self.norm = nn.LayerNorm(d_model)
96
+
97
+ def forward(self, x):
98
+ residual = x
99
+ out = self.conv(x.transpose(1, 2)).transpose(1, 2)
100
+ if out.size(1) != residual.size(1):
101
+ out = out[:, :residual.size(1), :]
102
+ return self.norm(out + residual)
103
+
104
+
105
+ class DilatedTemporalCNN(nn.Module):
106
+ """膨胀时序CNN (参考wav2sleep Sequence Mixer, 感受野≈6小时)"""
107
+ def __init__(self, d_model=128, kernel_size=7, dilations=None, n_blocks=2, dropout=0.1):
108
+ super().__init__()
109
+ if dilations is None:
110
+ dilations = [1, 2, 4, 8, 16, 32]
111
+ self.layers = nn.ModuleList([
112
+ DilatedResidualBlock(d_model, kernel_size, d, dropout)
113
+ for _ in range(n_blocks) for d in dilations
114
+ ])
115
+ def forward(self, x):
116
+ for layer in self.layers:
117
+ x = layer(x)
118
+ return x
119
+
120
+
121
+ class SinusoidalPositionalEncoding(nn.Module):
122
+ def __init__(self, d_model, max_len=2000, dropout=0.1):
123
+ super().__init__()
124
+ self.dropout = nn.Dropout(p=dropout)
125
+ pe = torch.zeros(max_len, d_model)
126
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
127
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
128
+ pe[:, 0::2] = torch.sin(position * div_term)
129
+ pe[:, 1::2] = torch.cos(position * div_term)
130
+ self.register_buffer('pe', pe.unsqueeze(0))
131
+ def forward(self, x):
132
+ return self.dropout(x + self.pe[:, :x.size(1), :])
133
+
134
+
135
+ class SleepStageNet(nn.Module):
136
+ """
137
+ 睡眠分期模型 - 综合wav2sleep + Cross-Modal Transformer的最佳设计
138
+
139
+ 输入: (batch, seq_len, 4) - [HRV, HR, RR, Movement] per 30-sec epoch
140
+ 输出: (batch, seq_len, n_classes) - 每个epoch的睡眠分期logits
141
+ """
142
+ STAGE_NAMES = {0: 'Wake', 1: 'N1', 2: 'N2', 3: 'N3', 4: 'REM'}
143
+
144
+ def __init__(self, n_features=4, n_classes=5, d_model=128, nhead=4,
145
+ epoch_mixer_layers=2, dim_feedforward=512, seq_mixer_blocks=2,
146
+ seq_mixer_kernel=7, seq_mixer_dilations=None, max_seq_len=1500,
147
+ dropout=0.1, feature_mask_prob=0.3, use_efficient_attention=True):
148
+ super().__init__()
149
+ self.n_features, self.n_classes, self.d_model = n_features, n_classes, d_model
150
+ self.feature_mask_prob = feature_mask_prob
151
+ if seq_mixer_dilations is None:
152
+ seq_mixer_dilations = [1, 2, 4, 8, 16, 32]
153
+
154
+ self.simple_projection = FeatureProjection(n_features, d_model, dropout)
155
+ self.cross_feature_attn = EfficientCrossFeatureAttention(
156
+ n_features, d_model, nhead, epoch_mixer_layers, dim_feedforward, dropout)
157
+ self.fusion_gate = nn.Sequential(nn.Linear(d_model * 2, d_model), nn.Sigmoid())
158
+ self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout)
159
+ self.seq_mixer = DilatedTemporalCNN(d_model, seq_mixer_kernel, seq_mixer_dilations, seq_mixer_blocks, dropout)
160
+ self.classifier = nn.Sequential(
161
+ nn.Linear(d_model, d_model // 2), nn.GELU(), nn.Dropout(dropout),
162
+ nn.Linear(d_model // 2, n_classes))
163
+ self._init_weights()
164
+
165
+ def _init_weights(self):
166
+ for m in self.modules():
167
+ if isinstance(m, nn.Linear):
168
+ nn.init.xavier_uniform_(m.weight)
169
+ if m.bias is not None: nn.init.zeros_(m.bias)
170
+ elif isinstance(m, nn.Conv1d):
171
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
172
+ if m.bias is not None: nn.init.zeros_(m.bias)
173
+
174
+ def _stochastic_feature_mask(self, x):
175
+ if self.training and self.feature_mask_prob > 0:
176
+ mask = torch.bernoulli(torch.ones(x.shape[0], 1, x.shape[2], device=x.device) * (1 - self.feature_mask_prob))
177
+ while (mask.sum(dim=2) == 0).any():
178
+ mask = torch.bernoulli(torch.ones(x.shape[0], 1, x.shape[2], device=x.device) * (1 - self.feature_mask_prob))
179
+ x = x * mask
180
+ return x
181
+
182
+ def forward(self, x, mask=None):
183
+ x = self._stochastic_feature_mask(x)
184
+ proj = self.simple_projection(x)
185
+ attn = self.cross_feature_attn(x)
186
+ gate = self.fusion_gate(torch.cat([proj, attn], dim=-1))
187
+ fused = gate * proj + (1 - gate) * attn
188
+ fused = self.pos_encoding(fused)
189
+ temporal = self.seq_mixer(fused)
190
+ return self.classifier(temporal)
191
+
192
+ def predict(self, x):
193
+ self.eval()
194
+ with torch.no_grad():
195
+ return torch.argmax(self.forward(x), dim=-1)
196
+
197
+ def count_parameters(self):
198
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
199
+
200
+
201
+ class WeightedFocalLoss(nn.Module):
202
+ """加权Focal Loss (参考Cross-Modal Transformer + Mamba-sleep)"""
203
+ def __init__(self, class_weights=None, gamma=2.0, reduction='mean'):
204
+ super().__init__()
205
+ if class_weights is None:
206
+ class_weights = [1.0, 2.0, 1.0, 1.5, 1.5]
207
+ self.register_buffer('weight', torch.tensor(class_weights, dtype=torch.float32))
208
+ self.gamma, self.reduction = gamma, reduction
209
+
210
+ def forward(self, logits, targets):
211
+ if logits.dim() == 3:
212
+ logits, targets = logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
213
+ ce = F.cross_entropy(logits, targets, weight=self.weight, reduction='none')
214
+ focal = (1 - torch.exp(-ce)) ** self.gamma * ce
215
+ return focal.mean() if self.reduction == 'mean' else focal.sum() if self.reduction == 'sum' else focal
216
+
217
+
218
+ class SleepDataProcessor:
219
+ @staticmethod
220
+ def per_patient_normalize(features, night_ids):
221
+ import numpy as np
222
+ normalized = features.copy()
223
+ for nid in np.unique(night_ids):
224
+ mask = night_ids == nid
225
+ data = features[mask]
226
+ mean, std = data.mean(axis=0), data.std(axis=0)
227
+ std[std < 1e-8] = 1.0
228
+ normalized[mask] = (data - mean) / std
229
+ return normalized
230
+
231
+
232
+ MODEL_CONFIGS = {
233
+ 'small': dict(d_model=64, nhead=4, epoch_mixer_layers=1, dim_feedforward=256,
234
+ seq_mixer_blocks=1, seq_mixer_kernel=5, seq_mixer_dilations=[1,2,4,8,16], dropout=0.1),
235
+ 'base': dict(d_model=128, nhead=4, epoch_mixer_layers=2, dim_feedforward=512,
236
+ seq_mixer_blocks=2, seq_mixer_kernel=7, seq_mixer_dilations=[1,2,4,8,16,32], dropout=0.1),
237
+ 'large': dict(d_model=256, nhead=8, epoch_mixer_layers=3, dim_feedforward=1024,
238
+ seq_mixer_blocks=3, seq_mixer_kernel=7, seq_mixer_dilations=[1,2,4,8,16,32,64], dropout=0.15),
239
+ }
240
+
241
+ def create_model(config_name='base', n_features=4, n_classes=5, **kwargs):
242
+ config = MODEL_CONFIGS[config_name].copy()
243
+ config.update(kwargs)
244
+ model = SleepStageNet(n_features=n_features, n_classes=n_classes, **config)
245
+ print(f"Created SleepStageNet-{config_name} ({model.count_parameters():,} params)")
246
+ return model