| """ |
| Sleep Staging Model - ๅบไบ wav2sleep + Cross-Modal Transformer ็ๆททๅๆถๆ |
| |
| ๅ่ๆ็ฎ: |
| 1. wav2sleep (2411.04644) - ๅคๆจกๆ็ก็ ๅๆSOTA |
| 2. Cross-Modal Transformer (2208.06991) - ่ทจๆจกๆๆณจๆๅๆบๅถ |
| 3. SleepPPG-Net (2202.05735) - ็นๅพๅทฅ็จๅๆฏBiLSTMๅบ็บฟ |
| 4. Mamba-based Sleep Staging (2412.15947) - ่ฝป้็บงๅบๅๅปบๆจก |
| |
| ่พๅ
ฅ็นๅพ: HRV(็ฅ็ป็ถๆ), ๅฟ็(ๆดไฝๆฐดๅนณ), ๅผๅธ้ข็, ไฝๅจ |
| ่พๅบ: 4/5็ฑป็ก็ ๅๆ (Wake, N1, N2, N3, [REM]) |
| |
| ๆถๆ่ฎพ่ฎก (SleepStageNet): |
| โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ |
| โ 1. Feature Projection Layer (per-epoch) โ |
| โ Linear(n_features โ d_model) + LayerNorm + GELU โ |
| โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค |
| โ 2. Cross-Feature Attention (Epoch Mixer) โ |
| โ Transformer Encoder with CLS token โ |
| โ - ่ๅHRV/HR/RR/Movement็ไบคไบๅ
ณ็ณป โ |
| โ - ๅ่wav2sleep็Epoch Mixer่ฎพ่ฎก โ |
| โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค |
| โ 3. Temporal Context (Sequence Mixer) โ |
| โ Dilated Temporal CNN โ |
| โ - ๆ่ท็ก็ ๅจๆ็้ฟ็จๆถๅบไพ่ต โ |
| โ - dilations=[1,2,4,8,16,32], kernel=7 โ |
| โ - ๅ่wav2sleep็Sequence Mixer โ |
| โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโค |
| โ 4. Classification Head โ |
| โ Linear(d_model โ n_classes) + Softmax โ |
| โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from typing import Optional, Tuple |
|
|
|
|
class FeatureProjection(nn.Module):
    """Project low-dimensional input features into the model hidden size.

    Adapted from the SleepPPG-Net feature-extraction branch: a two-layer MLP
    that expands to ``2 * d_model`` and compresses back to ``d_model``, with
    LayerNorm, GELU and dropout on the output side.

    Input:  (..., n_features) per-epoch feature vector.
    Output: (..., d_model) embedding.
    """

    def __init__(self, n_features: int = 4, d_model: int = 128, dropout: float = 0.1):
        super().__init__()
        # Built as a flat list so the Sequential indices (and therefore the
        # state-dict keys) match the expand -> compress layer order exactly.
        layers = [
            nn.Linear(n_features, d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.LayerNorm(d_model),
            nn.GELU(),
            nn.Dropout(dropout),
        ]
        self.projection = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP projection to the trailing feature dimension."""
        return self.projection(x)
|
|
|
|
class EfficientCrossFeatureAttention(nn.Module):
    """Efficient cross-feature attention (Epoch Mixer).

    Follows wav2sleep's Epoch Mixer and the Cross-Modal Transformer idea:
    each scalar input feature is treated as an independent "modality" token,
    and a small Transformer encoder with a learned CLS token fuses their
    interactions. The CLS output is the fused per-epoch embedding.

    Input:  (B, T, n_features) — one scalar per feature per 30-s epoch.
    Output: (B, T, d_model)    — fused embedding for each epoch.
    """

    def __init__(self, n_features=4, d_model=128, nhead=4, num_layers=2,
                 dim_feedforward=512, dropout=0.1):
        super().__init__()
        self.n_features = n_features
        self.d_model = d_model
        # One embedding MLP per scalar feature channel.
        self.feature_embeddings = nn.ModuleList([
            nn.Sequential(nn.Linear(1, d_model), nn.GELU(), nn.LayerNorm(d_model))
            for _ in range(n_features)
        ])
        # Learned CLS token used as the fusion readout.
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        # Learned "which slot is this" embedding: CLS + n_features positions.
        self.feature_type_embedding = nn.Parameter(torch.randn(1, n_features + 1, d_model) * 0.02)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
            dropout=dropout, activation='gelu', batch_first=True, norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers, norm=nn.LayerNorm(d_model))

    def forward(self, features):
        """Fuse per-epoch features via CLS-token attention.

        features: (B, T, n_features) -> returns (B, T, d_model).
        """
        # Renamed the feature-count dim from `F` so it no longer shadows the
        # module-level `torch.nn.functional as F` alias.
        B, T, n_feat = features.shape
        flat = features.reshape(B * T, n_feat)
        # Embed each scalar feature into its own token: (B*T, n_features, d_model).
        embedded = torch.cat(
            [self.feature_embeddings[i](flat[:, i:i + 1]).unsqueeze(1)
             for i in range(self.n_features)],
            dim=1)
        cls = self.cls_token.expand(B * T, -1, -1)
        tokens = torch.cat([cls, embedded], dim=1) + self.feature_type_embedding
        encoded = self.transformer(tokens)
        # CLS output (slot 0) is the fused per-epoch representation.
        return encoded[:, 0, :].reshape(B, T, self.d_model)
|
|
|
|
class DilatedResidualBlock(nn.Module):
    """Dilated residual convolution block (after the wav2sleep Sequence Mixer).

    A dilated Conv1d followed by a pointwise Conv1d, with a residual
    connection and LayerNorm. Sequence length is preserved for ANY kernel
    size: padding is applied explicitly with F.pad, asymmetrically when
    ``dilation * (kernel_size - 1)`` is odd. The original symmetric
    ``padding=`` form produced a T-1 output for even kernel sizes, and the
    trim branch could not repair a *shorter* output, so the residual add
    crashed; for odd kernels behavior is unchanged.
    """

    def __init__(self, d_model, kernel_size=7, dilation=1, dropout=0.1):
        super().__init__()
        # Total padding needed for 'same' output length; put the extra
        # sample on the right when the total is odd (even kernel sizes).
        total_pad = (kernel_size - 1) * dilation
        self._pad_left = total_pad // 2
        self._pad_right = total_pad - self._pad_left
        self.conv = nn.Sequential(
            nn.Conv1d(d_model, d_model, kernel_size, dilation=dilation),
            nn.GELU(), nn.Dropout(dropout),
            nn.Conv1d(d_model, d_model, 1), nn.GELU(), nn.Dropout(dropout),
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """x: (B, T, d_model) -> (B, T, d_model)."""
        residual = x
        # Conv1d expects (B, C, T); zero-pad the time axis, convolve, restore.
        out = F.pad(x.transpose(1, 2), (self._pad_left, self._pad_right))
        out = self.conv(out).transpose(1, 2)
        return self.norm(out + residual)
|
|
|
|
class DilatedTemporalCNN(nn.Module):
    """Stacked dilated temporal CNN (wav2sleep-style Sequence Mixer).

    Repeats the dilation schedule `n_blocks` times; with the default schedule
    the receptive field spans roughly six hours of 30-second epochs.
    """

    def __init__(self, d_model=128, kernel_size=7, dilations=None, n_blocks=2, dropout=0.1):
        super().__init__()
        schedule = [1, 2, 4, 8, 16, 32] if dilations is None else dilations
        blocks = []
        for _ in range(n_blocks):
            for d in schedule:
                blocks.append(DilatedResidualBlock(d_model, kernel_size, d, dropout))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Apply every dilated residual block in sequence; shape-preserving."""
        for block in self.layers:
            x = block(x)
        return x
|
|
|
|
class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al. 2017) + dropout.

    Generalized to support odd `d_model`: the cosine half is filled from a
    truncated `div_term`, where the original even-only indexing raised a
    shape-mismatch error. Even `d_model` behavior is unchanged.
    """

    def __init__(self, d_model, max_len=2000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine channel than sine channel,
        # so truncate div_term to the number of odd-indexed slots.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Registered as a buffer: moves with .to(device), excluded from grads.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add positional encoding to x (B, T, d_model), then apply dropout."""
        return self.dropout(x + self.pe[:, :x.size(1), :])
|
|
|
|
class SleepStageNet(nn.Module):
    """Sleep staging model combining wav2sleep + Cross-Modal Transformer designs.

    Input:  (batch, seq_len, 4) — [HRV, HR, RR, Movement] per 30-sec epoch.
    Output: (batch, seq_len, n_classes) — per-epoch sleep-stage logits.

    Pipeline: gated fusion of an MLP projection and a cross-feature attention
    encoder (per epoch), sinusoidal positions, a dilated temporal CNN for
    long-range context, then a small classification head.
    """
    STAGE_NAMES = {0: 'Wake', 1: 'N1', 2: 'N2', 3: 'N3', 4: 'REM'}

    def __init__(self, n_features=4, n_classes=5, d_model=128, nhead=4,
                 epoch_mixer_layers=2, dim_feedforward=512, seq_mixer_blocks=2,
                 seq_mixer_kernel=7, seq_mixer_dilations=None, max_seq_len=1500,
                 dropout=0.1, feature_mask_prob=0.3, use_efficient_attention=True):
        # NOTE: `use_efficient_attention` is accepted for config compatibility
        # but currently unused — the efficient attention path is always built.
        super().__init__()
        self.n_features, self.n_classes, self.d_model = n_features, n_classes, d_model
        self.feature_mask_prob = feature_mask_prob
        if seq_mixer_dilations is None:
            seq_mixer_dilations = [1, 2, 4, 8, 16, 32]

        # Two parallel per-epoch encoders fused by a learned sigmoid gate.
        self.simple_projection = FeatureProjection(n_features, d_model, dropout)
        self.cross_feature_attn = EfficientCrossFeatureAttention(
            n_features, d_model, nhead, epoch_mixer_layers, dim_feedforward, dropout)
        self.fusion_gate = nn.Sequential(nn.Linear(d_model * 2, d_model), nn.Sigmoid())
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len, dropout)
        self.seq_mixer = DilatedTemporalCNN(d_model, seq_mixer_kernel, seq_mixer_dilations, seq_mixer_blocks, dropout)
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_model // 2, n_classes))
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform for linear layers, Kaiming-normal for Conv1d."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None: nn.init.zeros_(m.bias)

    def _stochastic_feature_mask(self, x):
        """Randomly zero whole feature channels during training (modality dropout).

        Resamples the whole batch mask until every sample keeps at least one
        channel, so no input row is ever fully zeroed. No-op in eval mode or
        when feature_mask_prob == 0.
        """
        if self.training and self.feature_mask_prob > 0:
            mask = torch.bernoulli(torch.ones(x.shape[0], 1, x.shape[2], device=x.device) * (1 - self.feature_mask_prob))
            while (mask.sum(dim=2) == 0).any():
                mask = torch.bernoulli(torch.ones(x.shape[0], 1, x.shape[2], device=x.device) * (1 - self.feature_mask_prob))
            x = x * mask
        return x

    def forward(self, x, mask=None):
        """x: (B, T, n_features) -> logits (B, T, n_classes). `mask` is unused."""
        x = self._stochastic_feature_mask(x)
        proj = self.simple_projection(x)
        attn = self.cross_feature_attn(x)
        # Per-dimension gate blends the two epoch encoders.
        gate = self.fusion_gate(torch.cat([proj, attn], dim=-1))
        fused = gate * proj + (1 - gate) * attn
        fused = self.pos_encoding(fused)
        temporal = self.seq_mixer(fused)
        return self.classifier(temporal)

    def predict(self, x):
        """Return per-epoch argmax class indices (B, T).

        Fix: the original left the model permanently in eval mode; the
        caller's train/eval state is now saved and restored.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return torch.argmax(self.forward(x), dim=-1)
        finally:
            self.train(was_training)

    def count_parameters(self):
        """Number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|
|
class WeightedFocalLoss(nn.Module):
    """Class-weighted focal loss (cf. Cross-Modal Transformer / Mamba-sleep).

    Applies the (1 - p_t)^gamma modulation on top of class-weighted
    cross-entropy, down-weighting easy examples. The default weights
    emphasise N1 (and mildly N3/REM), the hardest sleep stages.
    """

    def __init__(self, class_weights=None, gamma=2.0, reduction='mean'):
        super().__init__()
        weights = [1.0, 2.0, 1.0, 1.5, 1.5] if class_weights is None else class_weights
        # Buffer so the weights follow the module across devices.
        self.register_buffer('weight', torch.tensor(weights, dtype=torch.float32))
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        """logits: (B, T, C) or (N, C); targets: matching integer labels."""
        if logits.dim() == 3:
            n_classes = logits.size(-1)
            logits = logits.reshape(-1, n_classes)
            targets = targets.reshape(-1)
        ce = F.cross_entropy(logits, targets, weight=self.weight, reduction='none')
        # p_t = exp(-ce); focal modulating factor is (1 - p_t)^gamma.
        focal = (1.0 - torch.exp(-ce)) ** self.gamma * ce
        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        return focal
|
|
|
|
class SleepDataProcessor:
    """Feature preprocessing utilities for sleep-staging inputs."""

    @staticmethod
    def per_patient_normalize(features, night_ids):
        """Z-score `features` independently within each recording night.

        features:  (N, F) float array of per-epoch features.
        night_ids: (N,) array; rows sharing an id are normalised together.
        Near-zero stds are clamped to 1.0 to avoid division blow-ups
        (constant columns map to zeros). Returns a new array.
        """
        import numpy as np
        out = features.copy()
        for night in np.unique(night_ids):
            rows = night_ids == night
            block = features[rows]
            mu = block.mean(axis=0)
            sigma = block.std(axis=0)
            sigma[sigma < 1e-8] = 1.0
            out[rows] = (block - mu) / sigma
        return out
|
|
|
|
# Preset hyper-parameter bundles for SleepStageNet. Keys mirror keyword
# arguments of SleepStageNet.__init__ and are consumed by create_model().
MODEL_CONFIGS = {
    # small: half-width model with a shorter dilation schedule (fast / debug).
    'small': dict(d_model=64, nhead=4, epoch_mixer_layers=1, dim_feedforward=256,
                  seq_mixer_blocks=1, seq_mixer_kernel=5, seq_mixer_dilations=[1,2,4,8,16], dropout=0.1),
    # base: the default configuration.
    'base': dict(d_model=128, nhead=4, epoch_mixer_layers=2, dim_feedforward=512,
                 seq_mixer_blocks=2, seq_mixer_kernel=7, seq_mixer_dilations=[1,2,4,8,16,32], dropout=0.1),
    # large: widest/deepest variant with a longer receptive field and more dropout.
    'large': dict(d_model=256, nhead=8, epoch_mixer_layers=3, dim_feedforward=1024,
                  seq_mixer_blocks=3, seq_mixer_kernel=7, seq_mixer_dilations=[1,2,4,8,16,32,64], dropout=0.15),
}
|
|
def create_model(config_name='base', n_features=4, n_classes=5, **kwargs):
    """Build a SleepStageNet from a named preset, with keyword overrides.

    config_name: one of MODEL_CONFIGS ('small' | 'base' | 'large').
    kwargs:      override any preset hyper-parameter.
    Returns the constructed model (also prints its parameter count).
    """
    settings = {**MODEL_CONFIGS[config_name], **kwargs}
    model = SleepStageNet(n_features=n_features, n_classes=n_classes, **settings)
    print(f"Created SleepStageNet-{config_name} ({model.count_parameters():,} params)")
    return model
|
|