| """ |
| 睡眠分期模型完整训练脚本 |
| |
| 基于以下SOTA论文的最佳实践: |
| 1. wav2sleep (2411.04644) - 多模态睡眠分期SOTA, AdamW + Linear Warmup + Exp Decay |
| 2. Cross-Modal Transformer (2208.06991) - 跨模态注意力, 加权交叉熵 |
| 3. SleepPPG-Net (2202.05735) - Per-patient Z-score标准化 |
| 4. Mamba-sleep (2412.15947) - 类别频率加权 |
| |
| 数据集: abmallick/heart-breath-sleep-stage-dataset (HuggingFace Hub) |
| - 包含: heart_rate, respiratory_rate, HRV指标(hr_sdnn_5, hr_rmssd_5), 派生特征 |
| - 30秒epoch, 按night_id组织 |
| - 注: 缺少体动数据,使用HR变化率作为替代活动指标 |
| |
| 训练配置: |
| - optimizer: AdamW (lr=1e-3, weight_decay=1e-2) [wav2sleep] |
| - scheduler: CosineAnnealing with warmup [改进自wav2sleep的exp decay] |
| - batch_size: 16 (整夜数据) [wav2sleep] |
| - early stopping: patience=10 epochs [wav2sleep: 5, 我们适当放宽] |
| - loss: Weighted Focal Loss [Cross-Modal Transformer权重 + Focal Loss] |
| - augmentation: 随机特征翻转(p=0.5), 随机特征遮蔽(p=0.3) [wav2sleep] |
| """ |
|
|
| import os |
| import sys |
| import json |
| import time |
| import random |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
from torch.optim.lr_scheduler import OneCycleLR
| from collections import Counter |
| from sklearn.metrics import ( |
| accuracy_score, f1_score, cohen_kappa_score, |
| classification_report, confusion_matrix |
| ) |
|
|
| |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from sleep_staging_model import ( |
| SleepStageNet, WeightedFocalLoss, SleepDataProcessor, |
| create_model, MODEL_CONFIGS |
| ) |
|
|
| |
| try: |
| import trackio |
| HAS_TRACKIO = True |
| except ImportError: |
| HAS_TRACKIO = False |
|
|
| STAGE_NAMES = ['Wake', 'N1', 'N2', 'N3', 'REM'] |
|
|
|
|
| |
| |
| |
| class SleepNightDataset(Dataset): |
| """ |
| 以整夜为单位的睡眠数据集。 |
| |
| 每个样本是一整夜的特征序列和对应的睡眠分期标签。 |
| 参考 wav2sleep: "padding or truncating each recording to 10h (T=1200)" |
| """ |
| def __init__( |
| self, |
| features: np.ndarray, |
| labels: np.ndarray, |
| night_ids: np.ndarray, |
| max_seq_len: int = 1200, |
| augment: bool = False, |
| ): |
| self.max_seq_len = max_seq_len |
| self.augment = augment |
| |
| |
| self.nights = [] |
| unique_nights = np.unique(night_ids) |
| |
| for nid in unique_nights: |
| mask = night_ids == nid |
| night_features = features[mask] |
| night_labels = labels[mask] |
| |
            # Skip nights shorter than 10 epochs (5 minutes)
            if len(night_features) < 10:
| continue |
| |
| self.nights.append({ |
| 'features': night_features, |
| 'labels': night_labels, |
| 'night_id': nid, |
| 'length': len(night_features), |
| }) |
| |
| print(f" Dataset: {len(self.nights)} nights, " |
| f"avg length: {np.mean([n['length'] for n in self.nights]):.0f} epochs") |
| |
| def __len__(self): |
| return len(self.nights) |
| |
| def __getitem__(self, idx): |
| night = self.nights[idx] |
| features = night['features'].copy() |
| labels = night['labels'].copy() |
| length = night['length'] |
| |
| |
| if self.augment: |
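            # Augmentations on z-scored features: with p=0.5, negate a random
            # subset of channels (a sign flip preserves scale after z-scoring);
            # with p=0.3, add light Gaussian noise (sigma = 0.05).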
| |
| if random.random() < 0.5: |
| |
| flip_mask = np.random.random(features.shape[1]) > 0.5 |
| features[:, flip_mask] = -features[:, flip_mask] |
| |
| |
| if random.random() < 0.3: |
| noise = np.random.normal(0, 0.05, features.shape) |
| features = features + noise |
| |
| |
        # Pad or truncate to max_seq_len (wav2sleep: 10 h = 1200 epochs).
        # Padded label positions get -1 so the loss and metrics can skip them.
        if length >= self.max_seq_len:
| features = features[:self.max_seq_len] |
| labels = labels[:self.max_seq_len] |
| actual_len = self.max_seq_len |
| else: |
| pad_len = self.max_seq_len - length |
| features = np.pad(features, ((0, pad_len), (0, 0)), 'constant') |
| labels = np.pad(labels, (0, pad_len), 'constant', constant_values=-1) |
| actual_len = length |
| |
| return { |
| 'features': torch.tensor(features, dtype=torch.float32), |
| 'labels': torch.tensor(labels, dtype=torch.long), |
| 'length': actual_len, |
| 'night_id': night['night_id'], |
| } |
|
|
|
|
| |
| |
| |
| def load_and_preprocess_data(max_seq_len=1200, cache_path='/app/sleep_data.npz', |
| subset_size=500000): |
| """ |
| 从HuggingFace加载数据集并预处理。 |
| |
| 数据集: abmallick/heart-breath-sleep-stage-dataset |
| |
| 特征工程: |
| 1. 直接特征: heart_rate, respiratory_rate |
| 2. HRV特征: hr_rmssd_5 (RMSSD, 反映副交感神经活性) |
| 3. 活动指标: 用hr_slope_3 (心率变化率) 作为体动的替代指标 |
| (数据集没有accelerometer, 但心率快速变化通常与体动相关) |
| 4. Per-patient Z-score标准化 (SleepPPG-Net: 最关键的预处理步骤) |
| """ |
| feature_columns = ['hr_rmssd_5', 'heart_rate', 'respiratory_rate', 'hr_slope_3'] |
| |
| print("=" * 60) |
| print("Loading dataset...") |
| print("=" * 60) |
| |
| |
| if os.path.exists(cache_path): |
| print(f"Loading from cache: {cache_path}") |
| data = np.load(cache_path) |
| hr = data['heart_rate'] |
| rr = data['respiratory_rate'] |
| hrv = data['hr_rmssd_5'] |
| slope = data['hr_slope_3'] |
| night_ids = data['night_id'] |
| raw_stages = data['sleep_stage'] |
| else: |
| print("Loading from HuggingFace Hub (this may take a while)...") |
| from datasets import load_dataset |
| ds = load_dataset("abmallick/heart-breath-sleep-stage-dataset", split="train") |
| |
| |
        # Take the first subset_size rows (note: this may truncate the last night)
        ds_sub = ds.select(range(min(subset_size, len(ds))))
| hr = np.array(ds_sub['heart_rate'], dtype=np.float32) |
| rr = np.array(ds_sub['respiratory_rate'], dtype=np.float32) |
| hrv = np.array(ds_sub['hr_rmssd_5'], dtype=np.float32) |
| slope = np.array(ds_sub['hr_slope_3'], dtype=np.float32) |
| night_ids = np.array(ds_sub['night_id'], dtype=np.int32) |
| raw_stages = np.array(ds_sub['sleep_stage'], dtype=np.int32) |
| |
| |
| np.savez(cache_path, heart_rate=hr, respiratory_rate=rr, |
| hr_rmssd_5=hrv, hr_slope_3=slope, |
| night_id=night_ids, sleep_stage=raw_stages) |
| print(f"Cached to {cache_path}") |
| |
| print(f"Total records: {len(hr):,}") |
| print(f"Night IDs: {len(np.unique(night_ids))}") |
| |
| |
| |
| |
| |
| |
| unique_stages = np.unique(raw_stages) |
| print(f"Unique raw stages: {unique_stages}") |
| |
| stage_counts = dict(zip(*np.unique(raw_stages, return_counts=True))) |
| print(f"Raw stage distribution:") |
| for s, c in sorted(stage_counts.items()): |
| print(f" Stage {s}: {c:,} ({c/len(raw_stages)*100:.1f}%)") |
| |
| |
    # Collapse non-standard codes to Wake (0). Stages 0-4 are assumed to map to
    # Wake/N1/N2/N3/REM; codes 5 and 9 appear to be artifact/unscored epochs.
    labels = raw_stages.copy()
    labels[raw_stages == 5] = 0
    labels[raw_stages == 9] = 0
| |
    # REM is assumed to be coded as stage 4; fall back to 4 classes if absent
    has_rem = 4 in unique_stages
| |
| if has_rem: |
| n_classes = 5 |
| print("5-class classification: Wake, N1, N2, N3, REM") |
| else: |
| |
| n_classes = 4 |
| print("4-class classification: Wake, N1, N2, N3 (no REM in dataset)") |
| |
| mapped_counts = dict(zip(*np.unique(labels, return_counts=True))) |
| print(f"Mapped stage distribution:") |
| for s, c in sorted(mapped_counts.items()): |
| name = STAGE_NAMES[s] if s < len(STAGE_NAMES) else f"Stage{s}" |
| print(f" {name} ({s}): {c:,} ({c/len(labels)*100:.1f}%)") |
| |
| |
| features = np.stack([hrv, hr, rr, slope], axis=-1).astype(np.float32) |
| print(f"\nFeatures shape: {features.shape}") |
| print(f"Selected features: {feature_columns}") |
| |
| |
| for i, name in enumerate(feature_columns): |
| col = features[:, i] |
| nan_mask = np.isnan(col) | np.isinf(col) |
| if nan_mask.any(): |
| median_val = np.nanmedian(col) |
| col[nan_mask] = median_val |
| print(f" Fixed {nan_mask.sum()} NaN/Inf values in {name}") |
| |
| |
| print("\nApplying per-patient Z-score normalization...") |
| features = SleepDataProcessor.per_patient_normalize(features, night_ids) |
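    # per_patient_normalize (imported above) is assumed to z-score each feature
    # within each night; a minimal sketch of the expected behavior:
    #   for nid in np.unique(night_ids):
    #       m = night_ids == nid
    #       mu, sd = features[m].mean(0), features[m].std(0) + 1e-8
    #       features[m] = (features[m] - mu) / sd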
| |
| |
| features = np.clip(features, -5.0, 5.0) |
| |
| |
| unique_nights = np.unique(night_ids) |
| rng = np.random.RandomState(42) |
| rng.shuffle(unique_nights) |
| |
| n_total = len(unique_nights) |
| n_train = int(n_total * 0.8) |
| n_val = int(n_total * 0.1) |
| |
| train_nights = set(unique_nights[:n_train]) |
| val_nights = set(unique_nights[n_train:n_train + n_val]) |
| test_nights = set(unique_nights[n_train + n_val:]) |
| |
| print(f"\nData split:") |
| print(f" Train: {len(train_nights)} nights") |
| print(f" Val: {len(val_nights)} nights") |
| print(f" Test: {len(test_nights)} nights") |
| |
| |
| train_mask = np.isin(night_ids, list(train_nights)) |
| val_mask = np.isin(night_ids, list(val_nights)) |
| test_mask = np.isin(night_ids, list(test_nights)) |
| |
| train_dataset = SleepNightDataset( |
| features[train_mask], labels[train_mask], night_ids[train_mask], |
| max_seq_len=max_seq_len, augment=True |
| ) |
| val_dataset = SleepNightDataset( |
| features[val_mask], labels[val_mask], night_ids[val_mask], |
| max_seq_len=max_seq_len, augment=False |
| ) |
| test_dataset = SleepNightDataset( |
| features[test_mask], labels[test_mask], night_ids[test_mask], |
| max_seq_len=max_seq_len, augment=False |
| ) |
| |
| |
| train_labels = labels[train_mask] |
| class_counts = np.bincount(train_labels, minlength=n_classes) |
| total = class_counts.sum() |
| |
| class_weights = total / (n_classes * class_counts + 1e-8) |
| class_weights = np.clip(class_weights / class_weights.min(), 1.0, 5.0) |
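    # Worked example: counts [50, 10, 30, 10] over 100 epochs and 4 classes give
    # raw weights 100 / (4 * counts) = [0.5, 2.5, 0.83, 2.5]; dividing by the
    # minimum and clipping to [1, 5] yields [1.0, 5.0, 1.67, 5.0]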
| print(f"\nClass weights (inverse frequency): {class_weights.tolist()}") |
| |
| return train_dataset, val_dataset, test_dataset, class_weights, feature_columns, n_classes |
|
|
|
|
| |
| |
| |
| class EarlyStopping: |
| """早停机制 (参考wav2sleep: patience=5)""" |
| def __init__(self, patience=10, min_delta=1e-4): |
| self.patience = patience |
| self.min_delta = min_delta |
| self.counter = 0 |
| self.best_score = None |
| self.early_stop = False |
| self.best_model_state = None |
| |
| def __call__(self, val_score, model): |
| if self.best_score is None or val_score > self.best_score + self.min_delta: |
| self.best_score = val_score |
| self.counter = 0 |
| self.best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()} |
| else: |
| self.counter += 1 |
| if self.counter >= self.patience: |
| self.early_stop = True |
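
# Usage sketch for EarlyStopping:
#   stopper = EarlyStopping(patience=10)
#   call stopper(val_kappa, model) once per epoch; stop when stopper.early_stop
#   becomes True, then restore weights from stopper.best_model_state.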
|
|
|
|
| def evaluate(model, dataloader, device, n_classes=5): |
| """评估模型""" |
| model.eval() |
| all_preds = [] |
| all_labels = [] |
| total_loss = 0 |
| n_batches = 0 |
| |
    # Evaluation uses plain cross-entropy; padded positions (label -1) are ignored
    loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
| |
| with torch.no_grad(): |
| for batch in dataloader: |
| features = batch['features'].to(device) |
| labels = batch['labels'].to(device) |
| lengths = batch['length'] |
| |
| logits = model(features) |
| |
| |
| loss = loss_fn(logits.reshape(-1, n_classes), labels.reshape(-1)) |
| total_loss += loss.item() |
| n_batches += 1 |
| |
| |
| preds = torch.argmax(logits, dim=-1) |
| |
| for i in range(len(lengths)): |
                length = min(int(lengths[i]), logits.size(1))
| valid_preds = preds[i, :length].cpu().numpy() |
| valid_labels = labels[i, :length].cpu().numpy() |
| |
| |
| valid_mask = valid_labels >= 0 |
| all_preds.extend(valid_preds[valid_mask].tolist()) |
| all_labels.extend(valid_labels[valid_mask].tolist()) |
| |
| all_preds = np.array(all_preds) |
| all_labels = np.array(all_labels) |
| |
| |
| accuracy = accuracy_score(all_labels, all_preds) |
| f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0) |
| f1_weighted = f1_score(all_labels, all_preds, average='weighted', zero_division=0) |
| kappa = cohen_kappa_score(all_labels, all_preds) |
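    # Cohen's kappa = (p_o - p_e) / (1 - p_e): agreement corrected for chance,
    # the usual headline metric in sleep staging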
| avg_loss = total_loss / max(n_batches, 1) |
| |
| return { |
| 'loss': avg_loss, |
| 'accuracy': accuracy, |
| 'f1_macro': f1_macro, |
| 'f1_weighted': f1_weighted, |
| 'kappa': kappa, |
| 'preds': all_preds, |
| 'labels': all_labels, |
| } |
|
|
|
|
| def train( |
| model_config='base', |
| n_features=4, |
| n_classes=5, |
| max_seq_len=1200, |
| batch_size=16, |
| lr=1e-3, |
| weight_decay=1e-2, |
| max_epochs=100, |
| patience=10, |
| warmup_epochs=5, |
| device='auto', |
| save_dir='./checkpoints', |
| project_name='sleep-staging', |
| run_name=None, |
| ): |
| """ |
| 完整训练流程 |
| |
| 超参数来源: |
| - lr=1e-3, weight_decay=1e-2: wav2sleep Section 4.2 |
| - batch_size=16: wav2sleep Section 4.2 |
| - patience=10: 改进自wav2sleep(5), 给更多探索空间 |
| - warmup: wav2sleep使用2000步线性warmup |
| """ |
| |
| if device == 'auto': |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' |
| print(f"\nDevice: {device}") |
| |
| |
| torch.manual_seed(42) |
| np.random.seed(42) |
| random.seed(42) |
| if torch.cuda.is_available(): |
| torch.cuda.manual_seed_all(42) |
| |
| |
| train_dataset, val_dataset, test_dataset, class_weights, feature_names, n_classes_data = \ |
| load_and_preprocess_data(max_seq_len=max_seq_len) |
| n_classes = n_classes_data |
| |
| train_loader = DataLoader( |
| train_dataset, batch_size=batch_size, shuffle=True, |
| num_workers=0, pin_memory=(device == 'cuda'), |
| drop_last=True, |
| ) |
| val_loader = DataLoader( |
| val_dataset, batch_size=batch_size, shuffle=False, |
| num_workers=0, pin_memory=(device == 'cuda'), |
| ) |
| test_loader = DataLoader( |
| test_dataset, batch_size=batch_size, shuffle=False, |
| num_workers=0, pin_memory=(device == 'cuda'), |
| ) |
| |
| |
| model = create_model(model_config, n_features=n_features, n_classes=n_classes) |
| model = model.to(device) |
| |
| |
| loss_fn = WeightedFocalLoss( |
| class_weights=class_weights.tolist(), |
| gamma=2.0, |
| ).to(device) |
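    # WeightedFocalLoss is assumed to implement
    # FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)  (Lin et al., 2017),
    # with alpha_t from the inverse-frequency class weights and gamma=2.0
    # down-weighting easy examples.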
| |
| |
| optimizer = torch.optim.AdamW( |
| model.parameters(), |
| lr=lr, |
| weight_decay=weight_decay, |
| betas=(0.9, 0.999), |
| ) |
| |
| |
| total_steps = len(train_loader) * max_epochs |
| scheduler = OneCycleLR( |
| optimizer, |
| max_lr=lr, |
| total_steps=total_steps, |
| pct_start=0.1, |
| anneal_strategy='cos', |
| div_factor=10, |
| final_div_factor=100, |
| ) |
| |
| |
| early_stopping = EarlyStopping(patience=patience) |
| |
| |
    os.makedirs(save_dir, exist_ok=True)

    if run_name is None:
        run_name = f"sleepnet_{model_config}_lr{lr}_bs{batch_size}"

    # Initialize trackio whenever it is installed so the trackio.log calls
    # below always have an active run; TRACKIO_SPACE_ID is only needed when
    # logging to a HuggingFace Space.
    if HAS_TRACKIO:
        trackio.init(
            project=project_name,
            name=run_name,
            space_id=os.environ.get('TRACKIO_SPACE_ID', None),
        )
| |
| |
| config = { |
| 'model_config': model_config, |
| 'n_features': n_features, |
| 'n_classes': n_classes, |
| 'feature_names': feature_names, |
| 'max_seq_len': max_seq_len, |
| 'batch_size': batch_size, |
| 'lr': lr, |
| 'weight_decay': weight_decay, |
| 'max_epochs': max_epochs, |
| 'patience': patience, |
| 'class_weights': class_weights.tolist(), |
| 'n_parameters': model.count_parameters(), |
| 'device': device, |
| } |
| |
| with open(os.path.join(save_dir, 'config.json'), 'w') as f: |
| json.dump(config, f, indent=2) |
| |
| print(f"\n{'='*60}") |
| print(f"Training: {run_name}") |
| print(f"{'='*60}") |
| print(f" Model: SleepStageNet-{model_config} ({model.count_parameters():,} params)") |
| print(f" Features: {feature_names}") |
| print(f" Train: {len(train_dataset)} nights | Val: {len(val_dataset)} | Test: {len(test_dataset)}") |
| print(f" Batch size: {batch_size} | LR: {lr} | Weight decay: {weight_decay}") |
| print(f" Max epochs: {max_epochs} | Early stopping patience: {patience}") |
| print(f"{'='*60}\n") |
| |
| |
| best_kappa = -1 |
| history = [] |
| |
| for epoch in range(1, max_epochs + 1): |
| |
| model.train() |
| train_loss = 0 |
| n_batches = 0 |
| epoch_start = time.time() |
| |
| for batch_idx, batch in enumerate(train_loader): |
| features = batch['features'].to(device) |
| labels = batch['labels'].to(device) |
| |
| optimizer.zero_grad() |
| logits = model(features) |
| |
| |
            # Flatten and drop padded positions (label == -1) before the loss
            valid_mask = labels.reshape(-1) >= 0
| if valid_mask.sum() > 0: |
| valid_logits = logits.reshape(-1, n_classes)[valid_mask] |
| valid_labels = labels.reshape(-1)[valid_mask] |
| loss = loss_fn(valid_logits, valid_labels) |
| else: |
| continue |
| |
| loss.backward() |
| |
| |
| torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) |
| |
| optimizer.step() |
| scheduler.step() |
| |
| train_loss += loss.item() |
| n_batches += 1 |
| |
| |
| if HAS_TRACKIO and (batch_idx + 1) % 10 == 0: |
| trackio.log({ |
| 'train/loss': loss.item(), |
| 'train/lr': scheduler.get_last_lr()[0], |
| }) |
| |
| avg_train_loss = train_loss / max(n_batches, 1) |
| epoch_time = time.time() - epoch_start |
| |
| |
| val_metrics = evaluate(model, val_loader, device, n_classes) |
| |
| |
| epoch_info = { |
| 'epoch': epoch, |
| 'train_loss': avg_train_loss, |
| 'val_loss': val_metrics['loss'], |
| 'val_accuracy': val_metrics['accuracy'], |
| 'val_f1_macro': val_metrics['f1_macro'], |
| 'val_kappa': val_metrics['kappa'], |
| 'lr': scheduler.get_last_lr()[0], |
| 'time': epoch_time, |
| } |
| history.append(epoch_info) |
| |
| |
| print(f"Epoch {epoch:3d}/{max_epochs} | " |
| f"Train Loss: {avg_train_loss:.4f} | " |
| f"Val Loss: {val_metrics['loss']:.4f} | " |
| f"Val Acc: {val_metrics['accuracy']:.4f} | " |
| f"Val F1: {val_metrics['f1_macro']:.4f} | " |
| f"Val κ: {val_metrics['kappa']:.4f} | " |
| f"LR: {scheduler.get_last_lr()[0]:.2e} | " |
| f"Time: {epoch_time:.1f}s") |
| |
| |
| if HAS_TRACKIO: |
| trackio.log({ |
| 'epoch': epoch, |
| 'train/epoch_loss': avg_train_loss, |
| 'val/loss': val_metrics['loss'], |
| 'val/accuracy': val_metrics['accuracy'], |
| 'val/f1_macro': val_metrics['f1_macro'], |
| 'val/f1_weighted': val_metrics['f1_weighted'], |
| 'val/kappa': val_metrics['kappa'], |
| }) |
| |
| |
| if val_metrics['kappa'] > best_kappa: |
| best_kappa = val_metrics['kappa'] |
| print(f" ★ New best κ: {best_kappa:.4f}") |
| |
| |
| torch.save({ |
| 'epoch': epoch, |
| 'model_state_dict': model.state_dict(), |
| 'optimizer_state_dict': optimizer.state_dict(), |
| 'val_metrics': val_metrics, |
| 'config': config, |
| }, os.path.join(save_dir, 'best_model.pt')) |
| |
| |
| early_stopping(val_metrics['kappa'], model) |
| if early_stopping.early_stop: |
| print(f"\n⚠ Early stopping at epoch {epoch} (patience={patience})") |
| break |
| |
| |
| print(f"\n{'='*60}") |
| print("Final Evaluation on Test Set") |
| print(f"{'='*60}") |
| |
| |
| if early_stopping.best_model_state is not None: |
| model.load_state_dict(early_stopping.best_model_state) |
| model = model.to(device) |
| |
| test_metrics = evaluate(model, test_loader, device, n_classes) |
| |
| print(f"\nTest Results:") |
| print(f" Accuracy: {test_metrics['accuracy']:.4f}") |
| print(f" F1 (macro): {test_metrics['f1_macro']:.4f}") |
| print(f" F1 (weight): {test_metrics['f1_weighted']:.4f}") |
| print(f" Cohen's κ: {test_metrics['kappa']:.4f}") |
| |
| |
    # Restrict the report to classes that actually occur in labels or predictions
    used_labels = sorted(set(test_metrics['labels'].tolist()) | set(test_metrics['preds'].tolist()))
    used_names = [STAGE_NAMES[i] if i < len(STAGE_NAMES) else f"Class{i}" for i in used_labels]
| |
| print(f"\nClassification Report:") |
| print(classification_report( |
| test_metrics['labels'], test_metrics['preds'], |
| labels=used_labels, target_names=used_names, zero_division=0, |
| )) |
| |
| |
| cm = confusion_matrix(test_metrics['labels'], test_metrics['preds'], labels=used_labels) |
| print("Confusion Matrix:") |
| print(f"{'':>8}", end='') |
| for name in used_names: |
| print(f"{name:>8}", end='') |
| print() |
| for i, name in enumerate(used_names): |
| print(f"{name:>8}", end='') |
| for j in range(len(used_names)): |
| print(f"{cm[i,j]:>8}", end='') |
| print() |
| |
| |
| if HAS_TRACKIO: |
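        # κ bands loosely follow Landis & Koch (1977): 0.41-0.60 moderate,
        # 0.61-0.80 substantial agreement.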
| if test_metrics['kappa'] >= 0.60: |
| trackio.alert( |
| "Training Complete - Good Performance", |
| f"κ={test_metrics['kappa']:.4f}, F1={test_metrics['f1_macro']:.4f}, " |
| f"Acc={test_metrics['accuracy']:.4f}. Model is usable for deployment.", |
| level="info" |
| ) |
| elif test_metrics['kappa'] >= 0.40: |
| trackio.alert( |
| "Training Complete - Moderate Performance", |
| f"κ={test_metrics['kappa']:.4f}. Consider: (1) more data, " |
| f"(2) larger model, (3) additional features.", |
| level="warn" |
| ) |
| else: |
| trackio.alert( |
| "Training Complete - Low Performance", |
| f"κ={test_metrics['kappa']:.4f}. Needs investigation: " |
| f"check data quality, feature engineering, or model capacity.", |
| level="error" |
| ) |
| |
| |
| results = { |
| 'test_accuracy': test_metrics['accuracy'], |
| 'test_f1_macro': test_metrics['f1_macro'], |
| 'test_f1_weighted': test_metrics['f1_weighted'], |
| 'test_kappa': test_metrics['kappa'], |
| 'best_val_kappa': best_kappa, |
| 'total_epochs': len(history), |
| 'config': config, |
| 'history': history, |
| } |
| |
| with open(os.path.join(save_dir, 'results.json'), 'w') as f: |
| json.dump(results, f, indent=2) |
| |
| |
| torch.save({ |
| 'model_state_dict': model.state_dict(), |
| 'config': config, |
| 'test_metrics': {k: v for k, v in test_metrics.items() |
| if k not in ('preds', 'labels')}, |
| 'feature_names': feature_names, |
| 'stage_names': STAGE_NAMES, |
| }, os.path.join(save_dir, 'final_model.pt')) |
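
    if HAS_TRACKIO:
        # Close the tracking run; trackio is assumed to mirror wandb's finish()
        trackio.finish()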
| |
| print(f"\n✅ Training complete! Models saved to {save_dir}") |
| print(f" Best validation κ: {best_kappa:.4f}") |
| print(f" Test κ: {test_metrics['kappa']:.4f}") |
| |
| return model, test_metrics, history |
|
|
|
|
| |
| |
| |
| def predict_sleep_stages( |
| model_path: str, |
| hrv_sequence: np.ndarray, |
| hr_sequence: np.ndarray, |
| rr_sequence: np.ndarray, |
| movement_sequence: np.ndarray, |
| device: str = 'auto', |
| ) -> dict: |
| """ |
| 使用训练好的模型进行睡眠分期预测。 |
| |
| Args: |
| model_path: 模型文件路径 |
| hrv_sequence: HRV序列 (RMSSD值, 每30秒一个值) |
| hr_sequence: 心率序列 (bpm, 每30秒一个值) |
| rr_sequence: 呼吸频率序列 (breaths/min, 每30秒一个值) |
| movement_sequence: 体动序列 (加速度/活动量, 每30秒一个值) |
| device: 计算设备 |
| |
| Returns: |
| dict: { |
| 'stages': 预测的睡眠分期序列 (0-4), |
| 'stage_names': 分期名称序列, |
| 'probabilities': 每个分期的概率, |
| 'summary': 睡眠摘要统计 |
| } |
| """ |
| if device == 'auto': |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' |
| |
| |
| checkpoint = torch.load(model_path, map_location=device, weights_only=False) |
| config = checkpoint['config'] |
| |
| model = create_model( |
| config['model_config'], |
| n_features=config['n_features'], |
| n_classes=config['n_classes'], |
| ) |
| model.load_state_dict(checkpoint['model_state_dict']) |
| model = model.to(device) |
| model.eval() |
| |
| |
| features = np.stack([ |
| hrv_sequence, hr_sequence, rr_sequence, movement_sequence |
| ], axis=-1).astype(np.float32) |
| |
| |
    # Mirror training-time preprocessing: per-recording z-score, then clip to
    # +/- 5 standard deviations
    mean = features.mean(axis=0)
    std = features.std(axis=0)
    std[std < 1e-8] = 1.0
    features = (features - mean) / std
    features = np.clip(features, -5.0, 5.0)
| |
| |
| x = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device) |
| |
| with torch.no_grad(): |
| logits = model(x) |
| probs = F.softmax(logits, dim=-1) |
| predictions = torch.argmax(logits, dim=-1) |
| |
| stages = predictions[0].cpu().numpy() |
| probabilities = probs[0].cpu().numpy() |
| |
| |
| stage_counts = Counter(stages.tolist()) |
| total_epochs = len(stages) |
| |
    # 30 s epochs: hours = n_epochs * 30 / 3600; per-stage minutes = count * 0.5
    summary = {
| 'total_time_hours': total_epochs * 30 / 3600, |
| 'sleep_efficiency': (total_epochs - stage_counts.get(0, 0)) / total_epochs * 100, |
| } |
| for i, name in enumerate(STAGE_NAMES): |
| count = stage_counts.get(i, 0) |
| summary[f'{name}_minutes'] = count * 0.5 |
| summary[f'{name}_percent'] = count / total_epochs * 100 |
| |
| return { |
| 'stages': stages, |
| 'stage_names': [STAGE_NAMES[s] for s in stages], |
| 'probabilities': probabilities, |
| 'summary': summary, |
| } |
|
|
|
|
| |
| |
| |
| if __name__ == '__main__': |
| import argparse |
| |
| parser = argparse.ArgumentParser(description='Sleep Stage Classification Training') |
| parser.add_argument('--model', type=str, default='base', choices=['small', 'base', 'large']) |
| parser.add_argument('--batch_size', type=int, default=16) |
| parser.add_argument('--lr', type=float, default=1e-3) |
| parser.add_argument('--weight_decay', type=float, default=1e-2) |
| parser.add_argument('--max_epochs', type=int, default=100) |
| parser.add_argument('--patience', type=int, default=10) |
| parser.add_argument('--max_seq_len', type=int, default=1200) |
| parser.add_argument('--save_dir', type=str, default='./checkpoints') |
| parser.add_argument('--device', type=str, default='auto') |
| parser.add_argument('--quick_test', action='store_true', help='Quick test with 3 epochs') |
| |
| args = parser.parse_args() |
| |
| if args.quick_test: |
| args.max_epochs = 3 |
| args.patience = 3 |
| |
| model, metrics, history = train( |
| model_config=args.model, |
| batch_size=args.batch_size, |
| lr=args.lr, |
| weight_decay=args.weight_decay, |
| max_epochs=args.max_epochs, |
| patience=args.patience, |
| max_seq_len=args.max_seq_len, |
| device=args.device, |
| save_dir=args.save_dir, |
| ) |
|
|