""" 睡眠分期模型完整训练脚本 基于以下SOTA论文的最佳实践: 1. wav2sleep (2411.04644) - 多模态睡眠分期SOTA, AdamW + Linear Warmup + Exp Decay 2. Cross-Modal Transformer (2208.06991) - 跨模态注意力, 加权交叉熵 3. SleepPPG-Net (2202.05735) - Per-patient Z-score标准化 4. Mamba-sleep (2412.15947) - 类别频率加权 数据集: abmallick/heart-breath-sleep-stage-dataset (HuggingFace Hub) - 包含: heart_rate, respiratory_rate, HRV指标(hr_sdnn_5, hr_rmssd_5), 派生特征 - 30秒epoch, 按night_id组织 - 注: 缺少体动数据,使用HR变化率作为替代活动指标 训练配置: - optimizer: AdamW (lr=1e-3, weight_decay=1e-2) [wav2sleep] - scheduler: CosineAnnealing with warmup [改进自wav2sleep的exp decay] - batch_size: 16 (整夜数据) [wav2sleep] - early stopping: patience=10 epochs [wav2sleep: 5, 我们适当放宽] - loss: Weighted Focal Loss [Cross-Modal Transformer权重 + Focal Loss] - augmentation: 随机特征翻转(p=0.5), 随机特征遮蔽(p=0.3) [wav2sleep] """ import os import sys import json import time import random import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR from collections import Counter from sklearn.metrics import ( accuracy_score, f1_score, cohen_kappa_score, classification_report, confusion_matrix ) # 导入我们的模型 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from sleep_staging_model import ( SleepStageNet, WeightedFocalLoss, SleepDataProcessor, create_model, MODEL_CONFIGS ) # 可选: trackio监控 try: import trackio HAS_TRACKIO = True except ImportError: HAS_TRACKIO = False STAGE_NAMES = ['Wake', 'N1', 'N2', 'N3', 'REM'] # ============================================================================ # 数据集 # ============================================================================ class SleepNightDataset(Dataset): """ 以整夜为单位的睡眠数据集。 每个样本是一整夜的特征序列和对应的睡眠分期标签。 参考 wav2sleep: "padding or truncating each recording to 10h (T=1200)" """ def __init__( self, features: np.ndarray, # (total_epochs, n_features) labels: np.ndarray, # (total_epochs,) night_ids: np.ndarray, # (total_epochs,) max_seq_len: int = 1200, augment: bool = False, ): self.max_seq_len = max_seq_len self.augment = augment # 按night_id分组 self.nights = [] unique_nights = np.unique(night_ids) for nid in unique_nights: mask = night_ids == nid night_features = features[mask] night_labels = labels[mask] if len(night_features) < 10: # 跳过太短的记录 continue self.nights.append({ 'features': night_features, 'labels': night_labels, 'night_id': nid, 'length': len(night_features), }) print(f" Dataset: {len(self.nights)} nights, " f"avg length: {np.mean([n['length'] for n in self.nights]):.0f} epochs") def __len__(self): return len(self.nights) def __getitem__(self, idx): night = self.nights[idx] features = night['features'].copy() labels = night['labels'].copy() length = night['length'] # 数据增强 (参考wav2sleep Section 4.2) if self.augment: # 随机特征翻转 (p=0.5) - "signals were randomly inverted" if random.random() < 0.5: # 只翻转连续特征,不翻转类别 flip_mask = np.random.random(features.shape[1]) > 0.5 features[:, flip_mask] = -features[:, flip_mask] # 添加少量高斯噪声 if random.random() < 0.3: noise = np.random.normal(0, 0.05, features.shape) features = features + noise # Pad或截断到固定长度 if length >= self.max_seq_len: features = features[:self.max_seq_len] labels = labels[:self.max_seq_len] actual_len = self.max_seq_len else: pad_len = self.max_seq_len - length features = np.pad(features, ((0, pad_len), (0, 0)), 'constant') labels = np.pad(labels, (0, pad_len), 'constant', constant_values=-1) actual_len = length return { 'features': torch.tensor(features, dtype=torch.float32), 'labels': torch.tensor(labels, dtype=torch.long), 'length': actual_len, 'night_id': night['night_id'], } # ============================================================================ # 数据加载和预处理 # ============================================================================ def load_and_preprocess_data(max_seq_len=1200, cache_path='/app/sleep_data.npz', subset_size=500000): """ 从HuggingFace加载数据集并预处理。 数据集: abmallick/heart-breath-sleep-stage-dataset 特征工程: 1. 直接特征: heart_rate, respiratory_rate 2. HRV特征: hr_rmssd_5 (RMSSD, 反映副交感神经活性) 3. 活动指标: 用hr_slope_3 (心率变化率) 作为体动的替代指标 (数据集没有accelerometer, 但心率快速变化通常与体动相关) 4. Per-patient Z-score标准化 (SleepPPG-Net: 最关键的预处理步骤) """ feature_columns = ['hr_rmssd_5', 'heart_rate', 'respiratory_rate', 'hr_slope_3'] print("=" * 60) print("Loading dataset...") print("=" * 60) # 尝试从缓存加载 if os.path.exists(cache_path): print(f"Loading from cache: {cache_path}") data = np.load(cache_path) hr = data['heart_rate'] rr = data['respiratory_rate'] hrv = data['hr_rmssd_5'] slope = data['hr_slope_3'] night_ids = data['night_id'] raw_stages = data['sleep_stage'] else: print("Loading from HuggingFace Hub (this may take a while)...") from datasets import load_dataset ds = load_dataset("abmallick/heart-breath-sleep-stage-dataset", split="train") # Take a subset for manageable training ds_sub = ds.select(range(min(subset_size, len(ds)))) hr = np.array(ds_sub['heart_rate'], dtype=np.float32) rr = np.array(ds_sub['respiratory_rate'], dtype=np.float32) hrv = np.array(ds_sub['hr_rmssd_5'], dtype=np.float32) slope = np.array(ds_sub['hr_slope_3'], dtype=np.float32) night_ids = np.array(ds_sub['night_id'], dtype=np.int32) raw_stages = np.array(ds_sub['sleep_stage'], dtype=np.int32) # Cache for next time np.savez(cache_path, heart_rate=hr, respiratory_rate=rr, hr_rmssd_5=hrv, hr_slope_3=slope, night_id=night_ids, sleep_stage=raw_stages) print(f"Cached to {cache_path}") print(f"Total records: {len(hr):,}") print(f"Night IDs: {len(np.unique(night_ids))}") # 睡眠分期标签映射 # 数据集标签: 0=Wake类, 1=N1, 2=N2, 3=N3, 5=Wake(AASM R&K编码), 9=未知 # 标准AASM 5类: Wake(0), N1(1), N2(2), N3(3), REM(4) # 注意: 此数据集没有REM标签(4), 只有0,1,2,3,5,9 # 5=Wake (R&K编码中5=Stage Wake), 9=Movement/Unknown → 合并到Wake unique_stages = np.unique(raw_stages) print(f"Unique raw stages: {unique_stages}") stage_counts = dict(zip(*np.unique(raw_stages, return_counts=True))) print(f"Raw stage distribution:") for s, c in sorted(stage_counts.items()): print(f" Stage {s}: {c:,} ({c/len(raw_stages)*100:.1f}%)") # 映射标签 labels = raw_stages.copy() labels[raw_stages == 5] = 0 # 5 → Wake labels[raw_stages == 9] = 0 # 9 → Wake (movement/unknown) # 检查是否有REM (4) has_rem = 4 in unique_stages if has_rem: n_classes = 5 print("5-class classification: Wake, N1, N2, N3, REM") else: # 没有REM标签 → 4类分类 n_classes = 4 print("4-class classification: Wake, N1, N2, N3 (no REM in dataset)") mapped_counts = dict(zip(*np.unique(labels, return_counts=True))) print(f"Mapped stage distribution:") for s, c in sorted(mapped_counts.items()): name = STAGE_NAMES[s] if s < len(STAGE_NAMES) else f"Stage{s}" print(f" {name} ({s}): {c:,} ({c/len(labels)*100:.1f}%)") # 组合特征: [HRV, HR, RR, Movement_proxy] features = np.stack([hrv, hr, rr, slope], axis=-1).astype(np.float32) print(f"\nFeatures shape: {features.shape}") print(f"Selected features: {feature_columns}") # 处理缺失值和异常值 for i, name in enumerate(feature_columns): col = features[:, i] nan_mask = np.isnan(col) | np.isinf(col) if nan_mask.any(): median_val = np.nanmedian(col) col[nan_mask] = median_val print(f" Fixed {nan_mask.sum()} NaN/Inf values in {name}") # Per-patient Z-score标准化 (SleepPPG-Net: 最关键步骤) print("\nApplying per-patient Z-score normalization...") features = SleepDataProcessor.per_patient_normalize(features, night_ids) # 裁剪极端值 features = np.clip(features, -5.0, 5.0) # 按night_id分割 train/val/test (80/10/10) unique_nights = np.unique(night_ids) rng = np.random.RandomState(42) rng.shuffle(unique_nights) n_total = len(unique_nights) n_train = int(n_total * 0.8) n_val = int(n_total * 0.1) train_nights = set(unique_nights[:n_train]) val_nights = set(unique_nights[n_train:n_train + n_val]) test_nights = set(unique_nights[n_train + n_val:]) print(f"\nData split:") print(f" Train: {len(train_nights)} nights") print(f" Val: {len(val_nights)} nights") print(f" Test: {len(test_nights)} nights") # 创建数据集 train_mask = np.isin(night_ids, list(train_nights)) val_mask = np.isin(night_ids, list(val_nights)) test_mask = np.isin(night_ids, list(test_nights)) train_dataset = SleepNightDataset( features[train_mask], labels[train_mask], night_ids[train_mask], max_seq_len=max_seq_len, augment=True ) val_dataset = SleepNightDataset( features[val_mask], labels[val_mask], night_ids[val_mask], max_seq_len=max_seq_len, augment=False ) test_dataset = SleepNightDataset( features[test_mask], labels[test_mask], night_ids[test_mask], max_seq_len=max_seq_len, augment=False ) # 计算类别权重 (基于训练集) train_labels = labels[train_mask] class_counts = np.bincount(train_labels, minlength=n_classes) total = class_counts.sum() # 逆频率加权 (参考Mamba-sleep) class_weights = total / (n_classes * class_counts + 1e-8) class_weights = np.clip(class_weights / class_weights.min(), 1.0, 5.0) print(f"\nClass weights (inverse frequency): {class_weights.tolist()}") return train_dataset, val_dataset, test_dataset, class_weights, feature_columns, n_classes # ============================================================================ # 训练循环 # ============================================================================ class EarlyStopping: """早停机制 (参考wav2sleep: patience=5)""" def __init__(self, patience=10, min_delta=1e-4): self.patience = patience self.min_delta = min_delta self.counter = 0 self.best_score = None self.early_stop = False self.best_model_state = None def __call__(self, val_score, model): if self.best_score is None or val_score > self.best_score + self.min_delta: self.best_score = val_score self.counter = 0 self.best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()} else: self.counter += 1 if self.counter >= self.patience: self.early_stop = True def evaluate(model, dataloader, device, n_classes=5): """评估模型""" model.eval() all_preds = [] all_labels = [] total_loss = 0 n_batches = 0 loss_fn = nn.CrossEntropyLoss(ignore_index=-1) with torch.no_grad(): for batch in dataloader: features = batch['features'].to(device) labels = batch['labels'].to(device) lengths = batch['length'] logits = model(features) # (batch, seq_len, n_classes) # 计算loss (忽略padding) loss = loss_fn(logits.reshape(-1, n_classes), labels.reshape(-1)) total_loss += loss.item() n_batches += 1 # 收集预测和标签 (只保留有效部分) preds = torch.argmax(logits, dim=-1) # (batch, seq_len) for i in range(len(lengths)): length = min(lengths[i], logits.size(1)) valid_preds = preds[i, :length].cpu().numpy() valid_labels = labels[i, :length].cpu().numpy() # 过滤padding标签 valid_mask = valid_labels >= 0 all_preds.extend(valid_preds[valid_mask].tolist()) all_labels.extend(valid_labels[valid_mask].tolist()) all_preds = np.array(all_preds) all_labels = np.array(all_labels) # 计算指标 accuracy = accuracy_score(all_labels, all_preds) f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0) f1_weighted = f1_score(all_labels, all_preds, average='weighted', zero_division=0) kappa = cohen_kappa_score(all_labels, all_preds) avg_loss = total_loss / max(n_batches, 1) return { 'loss': avg_loss, 'accuracy': accuracy, 'f1_macro': f1_macro, 'f1_weighted': f1_weighted, 'kappa': kappa, 'preds': all_preds, 'labels': all_labels, } def train( model_config='base', n_features=4, n_classes=5, max_seq_len=1200, batch_size=16, lr=1e-3, weight_decay=1e-2, max_epochs=100, patience=10, warmup_epochs=5, device='auto', save_dir='./checkpoints', project_name='sleep-staging', run_name=None, ): """ 完整训练流程 超参数来源: - lr=1e-3, weight_decay=1e-2: wav2sleep Section 4.2 - batch_size=16: wav2sleep Section 4.2 - patience=10: 改进自wav2sleep(5), 给更多探索空间 - warmup: wav2sleep使用2000步线性warmup """ # 设备选择 if device == 'auto': device = 'cuda' if torch.cuda.is_available() else 'cpu' print(f"\nDevice: {device}") # 设置随机种子 torch.manual_seed(42) np.random.seed(42) random.seed(42) if torch.cuda.is_available(): torch.cuda.manual_seed_all(42) # 加载数据 train_dataset, val_dataset, test_dataset, class_weights, feature_names, n_classes_data = \ load_and_preprocess_data(max_seq_len=max_seq_len) n_classes = n_classes_data # Use actual number of classes from data train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=(device == 'cuda'), drop_last=True, ) val_loader = DataLoader( val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=(device == 'cuda'), ) test_loader = DataLoader( test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=(device == 'cuda'), ) # 创建模型 model = create_model(model_config, n_features=n_features, n_classes=n_classes) model = model.to(device) # 损失函数 (加权Focal Loss) loss_fn = WeightedFocalLoss( class_weights=class_weights.tolist(), gamma=2.0, ).to(device) # 优化器 (AdamW, 参考wav2sleep) optimizer = torch.optim.AdamW( model.parameters(), lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999), ) # 学习率调度器 (OneCycleLR, 结合warmup和cosine decay) total_steps = len(train_loader) * max_epochs scheduler = OneCycleLR( optimizer, max_lr=lr, total_steps=total_steps, pct_start=0.1, # 10% warmup anneal_strategy='cos', div_factor=10, # 初始lr = max_lr/10 final_div_factor=100, # 最终lr = max_lr/1000 ) # 早停 early_stopping = EarlyStopping(patience=patience) # Trackio监控 if HAS_TRACKIO: space_id = os.environ.get('TRACKIO_SPACE_ID', None) if space_id: trackio.init( project=project_name, run=run_name or f"sleepnet_{model_config}_lr{lr}", space_id=space_id, ) # 训练目录 os.makedirs(save_dir, exist_ok=True) if run_name is None: run_name = f"sleepnet_{model_config}_lr{lr}_bs{batch_size}" # 记录配置 config = { 'model_config': model_config, 'n_features': n_features, 'n_classes': n_classes, 'feature_names': feature_names, 'max_seq_len': max_seq_len, 'batch_size': batch_size, 'lr': lr, 'weight_decay': weight_decay, 'max_epochs': max_epochs, 'patience': patience, 'class_weights': class_weights.tolist(), 'n_parameters': model.count_parameters(), 'device': device, } with open(os.path.join(save_dir, 'config.json'), 'w') as f: json.dump(config, f, indent=2) print(f"\n{'='*60}") print(f"Training: {run_name}") print(f"{'='*60}") print(f" Model: SleepStageNet-{model_config} ({model.count_parameters():,} params)") print(f" Features: {feature_names}") print(f" Train: {len(train_dataset)} nights | Val: {len(val_dataset)} | Test: {len(test_dataset)}") print(f" Batch size: {batch_size} | LR: {lr} | Weight decay: {weight_decay}") print(f" Max epochs: {max_epochs} | Early stopping patience: {patience}") print(f"{'='*60}\n") # ===== 训练循环 ===== best_kappa = -1 history = [] for epoch in range(1, max_epochs + 1): # --- 训练 --- model.train() train_loss = 0 n_batches = 0 epoch_start = time.time() for batch_idx, batch in enumerate(train_loader): features = batch['features'].to(device) labels = batch['labels'].to(device) optimizer.zero_grad() logits = model(features) # 忽略padding (-1标签) valid_mask = labels.reshape(-1) >= 0 if valid_mask.sum() > 0: valid_logits = logits.reshape(-1, n_classes)[valid_mask] valid_labels = labels.reshape(-1)[valid_mask] loss = loss_fn(valid_logits, valid_labels) else: continue loss.backward() # 梯度裁剪 (防止梯度爆炸) torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() scheduler.step() train_loss += loss.item() n_batches += 1 # Trackio logging if HAS_TRACKIO and (batch_idx + 1) % 10 == 0: trackio.log({ 'train/loss': loss.item(), 'train/lr': scheduler.get_last_lr()[0], }) avg_train_loss = train_loss / max(n_batches, 1) epoch_time = time.time() - epoch_start # --- 验证 --- val_metrics = evaluate(model, val_loader, device, n_classes) # 记录历史 epoch_info = { 'epoch': epoch, 'train_loss': avg_train_loss, 'val_loss': val_metrics['loss'], 'val_accuracy': val_metrics['accuracy'], 'val_f1_macro': val_metrics['f1_macro'], 'val_kappa': val_metrics['kappa'], 'lr': scheduler.get_last_lr()[0], 'time': epoch_time, } history.append(epoch_info) # 打印进度 print(f"Epoch {epoch:3d}/{max_epochs} | " f"Train Loss: {avg_train_loss:.4f} | " f"Val Loss: {val_metrics['loss']:.4f} | " f"Val Acc: {val_metrics['accuracy']:.4f} | " f"Val F1: {val_metrics['f1_macro']:.4f} | " f"Val κ: {val_metrics['kappa']:.4f} | " f"LR: {scheduler.get_last_lr()[0]:.2e} | " f"Time: {epoch_time:.1f}s") # Trackio logging if HAS_TRACKIO: trackio.log({ 'epoch': epoch, 'train/epoch_loss': avg_train_loss, 'val/loss': val_metrics['loss'], 'val/accuracy': val_metrics['accuracy'], 'val/f1_macro': val_metrics['f1_macro'], 'val/f1_weighted': val_metrics['f1_weighted'], 'val/kappa': val_metrics['kappa'], }) # 检查是否是最佳模型 if val_metrics['kappa'] > best_kappa: best_kappa = val_metrics['kappa'] print(f" ★ New best κ: {best_kappa:.4f}") # 保存最佳模型 torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'val_metrics': val_metrics, 'config': config, }, os.path.join(save_dir, 'best_model.pt')) # 早停检查 early_stopping(val_metrics['kappa'], model) if early_stopping.early_stop: print(f"\n⚠ Early stopping at epoch {epoch} (patience={patience})") break # ===== 测试评估 ===== print(f"\n{'='*60}") print("Final Evaluation on Test Set") print(f"{'='*60}") # 加载最佳模型 if early_stopping.best_model_state is not None: model.load_state_dict(early_stopping.best_model_state) model = model.to(device) test_metrics = evaluate(model, test_loader, device, n_classes) print(f"\nTest Results:") print(f" Accuracy: {test_metrics['accuracy']:.4f}") print(f" F1 (macro): {test_metrics['f1_macro']:.4f}") print(f" F1 (weight): {test_metrics['f1_weighted']:.4f}") print(f" Cohen's κ: {test_metrics['kappa']:.4f}") # 详细分类报告 actual_n_classes = len(set(test_metrics['labels'].tolist()) | set(test_metrics['preds'].tolist())) used_labels = sorted(set(test_metrics['labels'].tolist()) | set(test_metrics['preds'].tolist())) used_names = [STAGE_NAMES[i] if i < len(STAGE_NAMES) else f"Class{i}" for i in used_labels] print(f"\nClassification Report:") print(classification_report( test_metrics['labels'], test_metrics['preds'], labels=used_labels, target_names=used_names, zero_division=0, )) # 混淆矩阵 cm = confusion_matrix(test_metrics['labels'], test_metrics['preds'], labels=used_labels) print("Confusion Matrix:") print(f"{'':>8}", end='') for name in used_names: print(f"{name:>8}", end='') print() for i, name in enumerate(used_names): print(f"{name:>8}", end='') for j in range(len(used_names)): print(f"{cm[i,j]:>8}", end='') print() # Trackio alerts if HAS_TRACKIO: if test_metrics['kappa'] >= 0.60: trackio.alert( "Training Complete - Good Performance", f"κ={test_metrics['kappa']:.4f}, F1={test_metrics['f1_macro']:.4f}, " f"Acc={test_metrics['accuracy']:.4f}. Model is usable for deployment.", level="info" ) elif test_metrics['kappa'] >= 0.40: trackio.alert( "Training Complete - Moderate Performance", f"κ={test_metrics['kappa']:.4f}. Consider: (1) more data, " f"(2) larger model, (3) additional features.", level="warn" ) else: trackio.alert( "Training Complete - Low Performance", f"κ={test_metrics['kappa']:.4f}. Needs investigation: " f"check data quality, feature engineering, or model capacity.", level="error" ) # 保存最终结果 results = { 'test_accuracy': test_metrics['accuracy'], 'test_f1_macro': test_metrics['f1_macro'], 'test_f1_weighted': test_metrics['f1_weighted'], 'test_kappa': test_metrics['kappa'], 'best_val_kappa': best_kappa, 'total_epochs': len(history), 'config': config, 'history': history, } with open(os.path.join(save_dir, 'results.json'), 'w') as f: json.dump(results, f, indent=2) # 保存最终模型 torch.save({ 'model_state_dict': model.state_dict(), 'config': config, 'test_metrics': {k: v for k, v in test_metrics.items() if k not in ('preds', 'labels')}, 'feature_names': feature_names, 'stage_names': STAGE_NAMES, }, os.path.join(save_dir, 'final_model.pt')) print(f"\n✅ Training complete! Models saved to {save_dir}") print(f" Best validation κ: {best_kappa:.4f}") print(f" Test κ: {test_metrics['kappa']:.4f}") return model, test_metrics, history # ============================================================================ # 推理接口 # ============================================================================ def predict_sleep_stages( model_path: str, hrv_sequence: np.ndarray, hr_sequence: np.ndarray, rr_sequence: np.ndarray, movement_sequence: np.ndarray, device: str = 'auto', ) -> dict: """ 使用训练好的模型进行睡眠分期预测。 Args: model_path: 模型文件路径 hrv_sequence: HRV序列 (RMSSD值, 每30秒一个值) hr_sequence: 心率序列 (bpm, 每30秒一个值) rr_sequence: 呼吸频率序列 (breaths/min, 每30秒一个值) movement_sequence: 体动序列 (加速度/活动量, 每30秒一个值) device: 计算设备 Returns: dict: { 'stages': 预测的睡眠分期序列 (0-4), 'stage_names': 分期名称序列, 'probabilities': 每个分期的概率, 'summary': 睡眠摘要统计 } """ if device == 'auto': device = 'cuda' if torch.cuda.is_available() else 'cpu' # 加载模型 checkpoint = torch.load(model_path, map_location=device, weights_only=False) config = checkpoint['config'] model = create_model( config['model_config'], n_features=config['n_features'], n_classes=config['n_classes'], ) model.load_state_dict(checkpoint['model_state_dict']) model = model.to(device) model.eval() # 组合特征 features = np.stack([ hrv_sequence, hr_sequence, rr_sequence, movement_sequence ], axis=-1).astype(np.float32) # Z-score标准化 mean = features.mean(axis=0) std = features.std(axis=0) std[std < 1e-8] = 1.0 features = (features - mean) / std features = np.clip(features, -5.0, 5.0) # 转为tensor x = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device) with torch.no_grad(): logits = model(x) probs = F.softmax(logits, dim=-1) predictions = torch.argmax(logits, dim=-1) stages = predictions[0].cpu().numpy() probabilities = probs[0].cpu().numpy() # 睡眠摘要 stage_counts = Counter(stages.tolist()) total_epochs = len(stages) summary = { 'total_time_hours': total_epochs * 30 / 3600, 'sleep_efficiency': (total_epochs - stage_counts.get(0, 0)) / total_epochs * 100, } for i, name in enumerate(STAGE_NAMES): count = stage_counts.get(i, 0) summary[f'{name}_minutes'] = count * 0.5 # 30秒 = 0.5分钟 summary[f'{name}_percent'] = count / total_epochs * 100 return { 'stages': stages, 'stage_names': [STAGE_NAMES[s] for s in stages], 'probabilities': probabilities, 'summary': summary, } # ============================================================================ # 主入口 # ============================================================================ if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description='Sleep Stage Classification Training') parser.add_argument('--model', type=str, default='base', choices=['small', 'base', 'large']) parser.add_argument('--batch_size', type=int, default=16) parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--weight_decay', type=float, default=1e-2) parser.add_argument('--max_epochs', type=int, default=100) parser.add_argument('--patience', type=int, default=10) parser.add_argument('--max_seq_len', type=int, default=1200) parser.add_argument('--save_dir', type=str, default='./checkpoints') parser.add_argument('--device', type=str, default='auto') parser.add_argument('--quick_test', action='store_true', help='Quick test with 3 epochs') args = parser.parse_args() if args.quick_test: args.max_epochs = 3 args.patience = 3 model, metrics, history = train( model_config=args.model, batch_size=args.batch_size, lr=args.lr, weight_decay=args.weight_decay, max_epochs=args.max_epochs, patience=args.patience, max_seq_len=args.max_seq_len, device=args.device, save_dir=args.save_dir, )