# sleep-staging-model / sleep_staging_model.py
# Commit 0ab03e7 (verified) by Liuciba: "Add SleepStageNet model architecture"
"""
Sleep Staging Model - hybrid architecture based on wav2sleep + Cross-Modal Transformer

References:
1. wav2sleep (arXiv 2411.04644) - state-of-the-art multimodal sleep staging
2. Cross-Modal Transformer (arXiv 2208.06991) - cross-modal attention mechanism
3. SleepPPG-Net (arXiv 2202.05735) - BiLSTM baseline with a feature-engineering branch
4. Mamba-based Sleep Staging (arXiv 2412.15947) - lightweight sequence modeling

Input features: HRV (autonomic nervous state), heart rate (overall level),
respiratory rate, body movement.
Output: 4- or 5-class sleep staging (Wake, N1, N2, N3, [REM]).

Architecture (SleepStageNet):
+----------------------------------------------------------------+
| 1. Feature Projection Layer (per-epoch)                        |
|    Linear(n_features -> 2*d_model) + GELU                      |
|    -> Linear(2*d_model -> d_model) + LayerNorm + GELU          |
+----------------------------------------------------------------+
| 2. Cross-Feature Attention (Epoch Mixer)                       |
|    Transformer Encoder with CLS token                          |
|    - fuses the interactions among HRV/HR/RR/Movement           |
|    - modeled on the wav2sleep Epoch Mixer                      |
|    (outputs of 1. and 2. are blended by a learned sigmoid gate)|
+----------------------------------------------------------------+
| 3. Temporal Context (Sequence Mixer)                           |
|    Sinusoidal positional encoding + Dilated Temporal CNN       |
|    - captures long-range temporal dependencies across          |
|      sleep cycles                                              |
|    - dilations=[1,2,4,8,16,32], kernel=7 (base config)         |
|    - modeled on the wav2sleep Sequence Mixer                   |
+----------------------------------------------------------------+
| 4. Classification Head                                         |
|    Linear(d_model -> d_model // 2) + GELU                      |
|    -> Linear(d_model // 2 -> n_classes)                        |
|    Outputs raw logits; no softmax is applied in the model.     |
+----------------------------------------------------------------+
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple
class FeatureProjection(nn.Module):
    """
    Per-epoch MLP that lifts the low-dimensional input features to d_model
    (after the feature-engineering branch of SleepPPG-Net).

    Layer layout (inside ``self.projection``):
        Linear(n_features -> 2*d_model) -> GELU -> Dropout
        -> Linear(2*d_model -> d_model) -> LayerNorm -> GELU -> Dropout

    The Linear layers act on the last dimension, so inputs such as
    (batch, seq_len, n_features) map to (batch, seq_len, d_model).
    """

    def __init__(self, n_features: int = 4, d_model: int = 128, dropout: float = 0.1):
        super().__init__()
        hidden_width = 2 * d_model
        # Order matters: it fixes the state_dict keys (projection.0 / .3 / .4).
        stages = [
            nn.Linear(n_features, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, d_model),
            nn.LayerNorm(d_model),
            nn.GELU(),
            nn.Dropout(dropout),
        ]
        self.projection = nn.Sequential(*stages)

    def forward(self, x):
        """Map (..., n_features) to (..., d_model)."""
        lifted = self.projection(x)
        return lifted
class EfficientCrossFeatureAttention(nn.Module):
    """
    Cross-feature attention ("Epoch Mixer"), after the wav2sleep Epoch Mixer
    and the Cross-Modal Transformer.

    Each scalar input feature of an epoch is treated as its own modality
    token. A learned CLS token is prepended, a learned token-type embedding
    is added, and a Transformer encoder mixes the (n_features + 1) tokens.
    The CLS output becomes the epoch embedding. Epochs are encoded
    independently: the time axis is folded into the batch before attention.
    """

    def __init__(self, n_features=4, d_model=128, nhead=4, num_layers=2, dim_feedforward=512, dropout=0.1):
        super().__init__()
        self.n_features = n_features
        self.d_model = d_model
        # One small embedding network per feature: scalar -> d_model.
        self.feature_embeddings = nn.ModuleList([
            nn.Sequential(nn.Linear(1, d_model), nn.GELU(), nn.LayerNorm(d_model))
            for _ in range(n_features)
        ])
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        # Token-type embedding: index 0 is the CLS slot, 1..n_features are the features.
        self.feature_type_embedding = nn.Parameter(torch.randn(1, n_features + 1, d_model) * 0.02)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
            dropout=dropout, activation='gelu', batch_first=True, norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, norm=nn.LayerNorm(d_model))

    def forward(self, features):
        """
        Encode every epoch from its individual features.

        Args:
            features: (batch, seq_len, n_features) tensor.

        Returns:
            (batch, seq_len, d_model) tensor of per-epoch CLS embeddings.

        Raises:
            ValueError: if ``features`` is not 3-D with last dimension ``n_features``.
                Previously, extra columns were silently ignored.
        """
        if features.dim() != 3 or features.size(-1) != self.n_features:
            raise ValueError(
                f"expected features of shape (batch, seq_len, {self.n_features}), "
                f"got {tuple(features.shape)}"
            )
        # Named n_feat (not F) so torch.nn.functional is not shadowed.
        batch, seq_len, n_feat = features.shape
        flat = features.reshape(batch * seq_len, n_feat)
        # (B*T, n_features, d_model): one token per feature.
        embedded = torch.stack(
            [embed(flat[:, i:i + 1]) for i, embed in enumerate(self.feature_embeddings)],
            dim=1,
        )
        cls = self.cls_token.expand(batch * seq_len, -1, -1)
        tokens = torch.cat([cls, embedded], dim=1) + self.feature_type_embedding
        encoded = self.transformer(tokens)
        return encoded[:, 0, :].reshape(batch, seq_len, self.d_model)
class DilatedResidualBlock(nn.Module):
    """
    Dilated residual convolution block (after the wav2sleep Sequence Mixer).

    Input and output shape: (batch, seq_len, d_model). Computes
        LayerNorm(x + Dropout(GELU(Conv1x1(Dropout(GELU(DilatedConv(x)))))))
    where DilatedConv has the given kernel_size and dilation.
    """

    def __init__(self, d_model, kernel_size=7, dilation=1, dropout=0.1):
        super().__init__()
        self.conv = nn.Sequential(
            # padding='same' keeps seq_len unchanged for every kernel/dilation
            # combination. For odd kernel sizes it equals the symmetric padding
            # (kernel_size - 1) * dilation // 2. For even kernels with odd
            # dilation, the old symmetric padding produced an output one step
            # short, and the residual add crashed; PyTorch pads asymmetrically
            # here instead.
            nn.Conv1d(d_model, d_model, kernel_size, padding='same', dilation=dilation),
            nn.GELU(), nn.Dropout(dropout),
            nn.Conv1d(d_model, d_model, 1), nn.GELU(), nn.Dropout(dropout),
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply the residual conv block to a (batch, seq_len, d_model) tensor."""
        # Conv1d wants channels first: (B, T, C) -> (B, C, T) and back.
        out = self.conv(x.transpose(1, 2)).transpose(1, 2)
        return self.norm(out + x)
class DilatedTemporalCNN(nn.Module):
    """
    Stack of dilated residual conv blocks (wav2sleep-style Sequence Mixer).

    The dilation schedule is applied ``n_blocks`` times in a row. A conv with
    dilation d and kernel k widens the receptive field by (k - 1) * d epochs.
    The defaults (kernel 7, dilations 1..32, 2 passes) therefore cover
    1 + 2 * 6 * 63 = 757 epochs, about 6.3 hours of 30-second epochs.
    """

    def __init__(self, d_model=128, kernel_size=7, dilations=None, n_blocks=2, dropout=0.1):
        super().__init__()
        schedule = [1, 2, 4, 8, 16, 32] if dilations is None else dilations
        blocks = []
        for _repeat in range(n_blocks):
            for rate in schedule:
                blocks.append(DilatedResidualBlock(d_model, kernel_size, rate, dropout))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Run (batch, seq_len, d_model) through every block in order."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden
class SinusoidalPositionalEncoding(nn.Module):
    """
    Add fixed sine/cosine positional encodings along dim 1, then apply dropout.

    Channel 2k holds sin(pos * f_k) and channel 2k+1 holds cos(pos * f_k),
    with f_k = 10000 ** (-2k / d_model). A table for ``max_len`` positions is
    precomputed into the ``pe`` buffer, shape (1, max_len, d_model). Longer
    inputs are encoded on the fly instead of failing.
    """

    def __init__(self, d_model, max_len=2000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.register_buffer('pe', self._build_table(max_len, d_model).unsqueeze(0))

    @staticmethod
    def _build_table(length, d_model, device=None):
        """Return a (length, d_model) float32 table of sinusoidal encodings."""
        pe = torch.zeros(length, d_model, device=device)
        position = torch.arange(0, length, dtype=torch.float, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, device=device).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine channel than sine channel.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe

    def forward(self, x):
        """Add encodings to x of shape (batch, seq_len, d_model)."""
        seq_len = x.size(1)
        if seq_len <= self.pe.size(1):
            pe = self.pe[:, :seq_len, :]
        else:
            # Longer than the cached table: compute with the same formula.
            pe = self._build_table(seq_len, self.pe.size(-1), device=x.device)
            pe = pe.unsqueeze(0).to(self.pe.dtype)
        return self.dropout(x + pe)
class SleepStageNet(nn.Module):
    """
    Sleep staging model combining the wav2sleep design with a
    Cross-Modal-Transformer-style cross-feature attention.

    Input:  (batch, seq_len, n_features). By default this is
            [HRV, HR, RR, Movement] for each 30-second epoch.
    Output: (batch, seq_len, n_classes). These are per-epoch sleep-stage
            logits; no softmax is applied.

    Forward pipeline:
      1. Whole-feature random masking (training only; see
         ``_stochastic_feature_mask``).
      2. Two per-epoch encoders, an MLP projection and a cross-feature
         Transformer, blended element-wise by a learned sigmoid gate.
      3. Sinusoidal positional encoding.
      4. Dilated temporal CNN across epochs.
      5. Two-layer MLP classifier.
    """

    # Index -> stage label for the default 5-class setup.
    STAGE_NAMES = {0: 'Wake', 1: 'N1', 2: 'N2', 3: 'N3', 4: 'REM'}

    def __init__(self, n_features=4, n_classes=5, d_model=128, nhead=4,
                 epoch_mixer_layers=2, dim_feedforward=512, seq_mixer_blocks=2,
                 seq_mixer_kernel=7, seq_mixer_dilations=None, max_seq_len=1500,
                 dropout=0.1, feature_mask_prob=0.3, use_efficient_attention=True):
        """
        Args:
            n_features: input features per epoch.
            n_classes: number of sleep stages predicted.
            d_model: hidden width shared by every sub-module.
            nhead: attention heads in the cross-feature Transformer.
            epoch_mixer_layers: Transformer layers in the cross-feature encoder.
            dim_feedforward: Transformer feed-forward width.
            seq_mixer_blocks: passes over ``seq_mixer_dilations`` in the temporal CNN.
            seq_mixer_kernel: temporal convolution kernel size.
            seq_mixer_dilations: dilation schedule (default [1, 2, 4, 8, 16, 32]).
            max_seq_len: length of the precomputed positional-encoding table.
            dropout: dropout rate used throughout.
            feature_mask_prob: probability of dropping each input feature for a
                whole sequence during training. Must be < 1 when > 0.
            use_efficient_attention: currently unused; accepted so existing
                configs keep working.
        """
        super().__init__()
        self.n_features, self.n_classes, self.d_model = n_features, n_classes, d_model
        self.feature_mask_prob = feature_mask_prob
        if seq_mixer_dilations is None:
            seq_mixer_dilations = [1, 2, 4, 8, 16, 32]
        self.simple_projection = FeatureProjection(n_features, d_model, dropout)
        self.cross_feature_attn = EfficientCrossFeatureAttention(
            n_features, d_model, nhead, epoch_mixer_layers, dim_feedforward, dropout)
        # Per-channel gate in (0, 1) choosing between the two epoch encoders.
        self.fusion_gate = nn.Sequential(nn.Linear(d_model * 2, d_model), nn.Sigmoid())
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout)
        self.seq_mixer = DilatedTemporalCNN(d_model, seq_mixer_kernel, seq_mixer_dilations, seq_mixer_blocks, dropout)
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_model // 2, n_classes))
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every nn.Linear and nn.Conv1d submodule, recursively.

        nn.Linear gets Xavier-uniform weights; nn.Conv1d gets Kaiming-normal
        weights (fan_out, relu gain). All their biases are zeroed.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _stochastic_feature_mask(self, x):
        """
        During training, zero out whole input features per sample.

        One keep/drop decision is drawn per (sample, feature), with shape
        (batch, 1, n_features), and shared across every time step. A sample
        must keep at least one feature. Only the samples that lost all
        features are resampled; rows are independent, so this gives the same
        distribution as redrawing the whole batch, but terminates quickly.

        Raises:
            ValueError: if training with feature_mask_prob >= 1. No feature
                could ever be kept, so the loop would never end.
        """
        p = self.feature_mask_prob
        if not self.training or p <= 0:
            return x
        if p >= 1:
            raise ValueError(f"feature_mask_prob must be < 1 to keep at least one feature, got {p}")
        keep_prob = 1.0 - p
        n_feat = x.shape[2]
        mask = torch.bernoulli(torch.full((x.shape[0], 1, n_feat), keep_prob, device=x.device))
        all_dropped = mask.sum(dim=2).squeeze(1) == 0
        while all_dropped.any():
            n_redo = int(all_dropped.sum())
            mask[all_dropped] = torch.bernoulli(
                torch.full((n_redo, 1, n_feat), keep_prob, device=x.device))
            all_dropped = mask.sum(dim=2).squeeze(1) == 0
        # Match x's dtype so half-precision inputs are not promoted to float32.
        return x * mask.to(x.dtype)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, n_features) tensor.
            mask: currently ignored. Every position, padding included, is
                processed; padded epochs must be excluded by the caller.

        Returns:
            (batch, seq_len, n_classes) logits.
        """
        x = self._stochastic_feature_mask(x)
        proj = self.simple_projection(x)
        attn = self.cross_feature_attn(x)
        gate = self.fusion_gate(torch.cat([proj, attn], dim=-1))
        fused = gate * proj + (1 - gate) * attn
        fused = self.pos_encoding(fused)
        temporal = self.seq_mixer(fused)
        return self.classifier(temporal)

    def predict(self, x):
        """
        Return the argmax stage index per epoch, shape (batch, seq_len).

        Runs in eval mode without gradients. The module's previous
        training/eval mode is restored afterwards, so calling this inside a
        training loop does not silently disable dropout and feature masking.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return torch.argmax(self.forward(x), dim=-1)
        finally:
            self.train(was_training)

    def count_parameters(self):
        """Number of trainable (requires_grad) parameter elements."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
class WeightedFocalLoss(nn.Module):
    """
    Class-weighted focal loss (following the Cross-Modal Transformer and
    Mamba-sleep papers).

    For a sample with true class y and softmax probability p_t = p[y]:
        loss = w[y] * (1 - p_t) ** gamma * (-log p_t)

    The class weight is applied outside the focal term. The previous version
    fed the *weighted* cross-entropy into exp(-ce), which computes p_t ** w[y]
    instead of p_t and distorts the modulating factor.

    Targets equal to ``ignore_index`` contribute zero loss. They are also
    excluded from the 'mean' denominator.
    """

    def __init__(self, class_weights=None, gamma=2.0, reduction='mean', ignore_index=-100):
        super().__init__()
        if class_weights is None:
            # Default for 5 classes, presumably in STAGE_NAMES order
            # (Wake, N1, N2, N3, REM) - confirm against the training setup.
            class_weights = [1.0, 2.0, 1.0, 1.5, 1.5]
        self.register_buffer('weight', torch.as_tensor(class_weights, dtype=torch.float32))
        self.gamma, self.reduction = gamma, reduction
        self.ignore_index = ignore_index

    def forward(self, logits, targets):
        """
        Args:
            logits: (N, C) logits, or (batch, seq_len, C), which is flattened
                to (batch * seq_len, C).
            targets: matching integer class indices.

        Returns:
            A scalar for reduction 'mean' or 'sum'; otherwise the per-element loss.

        Raises:
            ValueError: if the number of class weights differs from C.
        """
        if logits.dim() == 3:
            logits, targets = logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
        n_classes = logits.size(1) if logits.dim() > 1 else logits.size(0)
        if self.weight.numel() != n_classes:
            raise ValueError(
                f"class_weights has {self.weight.numel()} entries but logits have {n_classes} classes"
            )
        # Unweighted CE, so that exp(-ce) is exactly p_t. Ignored targets give ce == 0.
        ce = F.cross_entropy(logits, targets, reduction='none', ignore_index=self.ignore_index)
        pt = torch.exp(-ce)
        valid = targets != self.ignore_index
        # Ignored targets are redirected to class 0 for the lookup; their ce is 0 anyway.
        w = self.weight[torch.where(valid, targets, torch.zeros_like(targets))]
        focal = w * (1 - pt) ** self.gamma * ce
        if self.reduction == 'mean':
            return focal.sum() / valid.sum().clamp(min=1)
        if self.reduction == 'sum':
            return focal.sum()
        return focal
class SleepDataProcessor:
    """Preprocessing helpers for per-epoch sleep feature matrices."""

    @staticmethod
    def per_patient_normalize(features, night_ids):
        """
        Z-score each feature column separately within every recording night.

        Args:
            features: array-like with one row per epoch (typically 2-D:
                epochs x features; 1-D is also accepted). It is not modified.
            night_ids: 1-D array-like with one id per row of ``features``.
                Any values ``np.unique`` can sort are accepted. Plain Python
                lists are supported; previously they silently skipped
                normalization.

        Returns:
            A new floating-point ndarray with the same shape. Integer input is
            promoted to float64 instead of truncating the z-scores; float32
            input stays float32. Columns constant within a night
            (std < 1e-8) are only mean-centred.

        Raises:
            ValueError: if ``night_ids`` is not 1-D or its length does not
                match the number of rows in ``features``.
        """
        import numpy as np

        feats = np.asarray(features)
        ids = np.asarray(night_ids)
        if ids.ndim != 1 or feats.ndim == 0 or len(ids) != len(feats):
            raise ValueError(
                f"night_ids must be 1-D with one entry per row of features; "
                f"got shapes {ids.shape} and {feats.shape}"
            )
        # Float output dtype, so values written back are not truncated.
        normalized = feats.astype(np.result_type(feats.dtype, np.float32))
        for nid in np.unique(ids):
            rows = ids == nid
            night = feats[rows]
            mean, std = night.mean(axis=0), night.std(axis=0)
            # Avoid division by ~0 for columns that are constant within this night.
            std = np.where(std < 1e-8, 1.0, std)
            normalized[rows] = (night - mean) / std
        return normalized
# Named hyper-parameter presets used by create_model(). Every entry holds
# SleepStageNet keyword arguments. The presets scale the hidden width
# (d_model), the epoch-mixer Transformer (layers, heads, feed-forward width),
# and the temporal CNN (repeat count, kernel size, dilation schedule).
MODEL_CONFIGS = {
    'small': {
        'd_model': 64,
        'nhead': 4,
        'epoch_mixer_layers': 1,
        'dim_feedforward': 256,
        'seq_mixer_blocks': 1,
        'seq_mixer_kernel': 5,
        'seq_mixer_dilations': [1, 2, 4, 8, 16],
        'dropout': 0.1,
    },
    'base': {
        'd_model': 128,
        'nhead': 4,
        'epoch_mixer_layers': 2,
        'dim_feedforward': 512,
        'seq_mixer_blocks': 2,
        'seq_mixer_kernel': 7,
        'seq_mixer_dilations': [1, 2, 4, 8, 16, 32],
        'dropout': 0.1,
    },
    'large': {
        'd_model': 256,
        'nhead': 8,
        'epoch_mixer_layers': 3,
        'dim_feedforward': 1024,
        'seq_mixer_blocks': 3,
        'seq_mixer_kernel': 7,
        'seq_mixer_dilations': [1, 2, 4, 8, 16, 32, 64],
        'dropout': 0.15,
    },
}
def create_model(config_name='base', n_features=4, n_classes=5, **kwargs):
    """
    Build a SleepStageNet from a named preset in MODEL_CONFIGS.

    Args:
        config_name: preset name, e.g. 'small', 'base' or 'large'.
        n_features: number of input features per epoch.
        n_classes: number of sleep stages to predict.
        **kwargs: SleepStageNet keyword arguments that override the preset.

    Returns:
        The constructed SleepStageNet. Its trainable parameter count is printed.

    Raises:
        KeyError: if ``config_name`` is not a known preset. The message lists
            the valid names.
    """
    try:
        preset = MODEL_CONFIGS[config_name]
    except KeyError:
        raise KeyError(
            f"Unknown config {config_name!r}; expected one of {sorted(MODEL_CONFIGS)}"
        ) from None
    # Copy list values as well, so the model never aliases the shared preset table.
    config = {key: list(val) if isinstance(val, list) else val for key, val in preset.items()}
    config.update(kwargs)
    model = SleepStageNet(n_features=n_features, n_classes=n_classes, **config)
    print(f"Created SleepStageNet-{config_name} ({model.count_parameters():,} params)")
    return model