sleep-staging-model / train_sleep_staging.py
Liuciba's picture
Upload complete training script
6530eb1 verified
"""
睡眠分期模型完整训练脚本
基于以下SOTA论文的最佳实践:
1. wav2sleep (2411.04644) - 多模态睡眠分期SOTA, AdamW + Linear Warmup + Exp Decay
2. Cross-Modal Transformer (2208.06991) - 跨模态注意力, 加权交叉熵
3. SleepPPG-Net (2202.05735) - Per-patient Z-score标准化
4. Mamba-sleep (2412.15947) - 类别频率加权
数据集: abmallick/heart-breath-sleep-stage-dataset (HuggingFace Hub)
- 包含: heart_rate, respiratory_rate, HRV指标(hr_sdnn_5, hr_rmssd_5), 派生特征
- 30秒epoch, 按night_id组织
- 注: 缺少体动数据,使用HR变化率作为替代活动指标
训练配置:
- optimizer: AdamW (lr=1e-3, weight_decay=1e-2) [wav2sleep]
- scheduler: OneCycleLR (warm-up over the first 10% of steps, then cosine annealing) [adapted from wav2sleep's exp decay]
- batch_size: 16 (整夜数据) [wav2sleep]
- early stopping: patience=10 epochs [wav2sleep: 5, 我们适当放宽]
- loss: Weighted Focal Loss [Cross-Modal Transformer权重 + Focal Loss]
- augmentation: random sign flip of a random subset of features (p=0.5) [wav2sleep], additive Gaussian noise with sigma=0.05 (p=0.3)
"""
import os
import sys
import json
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR
from collections import Counter
from sklearn.metrics import (
accuracy_score, f1_score, cohen_kappa_score,
classification_report, confusion_matrix
)
# 导入我们的模型
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from sleep_staging_model import (
SleepStageNet, WeightedFocalLoss, SleepDataProcessor,
create_model, MODEL_CONFIGS
)
# 可选: trackio监控
try:
import trackio
HAS_TRACKIO = True
except ImportError:
HAS_TRACKIO = False
# Display names of the sleep stages, indexed by integer class id
# (0=Wake, 1=N1, 2=N2, 3=N3, 4=REM).
STAGE_NAMES = ['Wake', 'N1', 'N2', 'N3', 'REM']
# ============================================================================
# 数据集
# ============================================================================
class SleepNightDataset(Dataset):
    """
    Sleep dataset organised by night: one sample is one whole recording.

    Each item carries the feature sequence of a night together with its
    per-epoch stage labels, padded or truncated to ``max_seq_len`` epochs
    (wav2sleep: "padding or truncating each recording to 10h (T=1200)").
    Padded label positions are filled with -1 and padded feature rows with 0.
    """

    def __init__(
        self,
        features: np.ndarray,   # (total_epochs, n_features)
        labels: np.ndarray,     # (total_epochs,)
        night_ids: np.ndarray,  # (total_epochs,)
        max_seq_len: int = 1200,
        augment: bool = False,
    ):
        self.max_seq_len = max_seq_len
        self.augment = augment

        # Group the flat epoch arrays by night id; very short nights are dropped.
        self.nights = []
        for nid in np.unique(night_ids):
            rows = night_ids == nid
            night_x = features[rows]
            n_epochs = len(night_x)
            if n_epochs < 10:
                continue
            self.nights.append(dict(
                features=night_x,
                labels=labels[rows],
                night_id=nid,
                length=n_epochs,
            ))

        mean_len = np.mean([night['length'] for night in self.nights])
        print(f" Dataset: {len(self.nights)} nights, avg length: {mean_len:.0f} epochs")

    def __len__(self):
        return len(self.nights)

    def _augment(self, x):
        """Apply the training-time augmentations to a private copy of one night."""
        # Sign inversion (wav2sleep Section 4.2, "signals were randomly inverted"):
        # with p=0.5 pick a random subset of feature columns and negate them.
        if random.random() < 0.5:
            flip = np.random.random(x.shape[1]) > 0.5
            x[:, flip] = np.negative(x[:, flip])
        # With p=0.3 add a small amount of Gaussian noise (sigma = 0.05).
        if random.random() < 0.3:
            x = x + np.random.normal(0, 0.05, x.shape)
        return x

    def __getitem__(self, idx):
        record = self.nights[idx]
        # Work on copies so augmentation never touches the stored arrays.
        x = record['features'].copy()
        y = record['labels'].copy()

        if self.augment:
            x = self._augment(x)

        # Bring the night to exactly max_seq_len epochs.
        shortfall = self.max_seq_len - record['length']
        if shortfall > 0:
            x = np.pad(x, ((0, shortfall), (0, 0)), 'constant')
            y = np.pad(y, (0, shortfall), 'constant', constant_values=-1)
            n_valid = record['length']
        else:
            x = x[:self.max_seq_len]
            y = y[:self.max_seq_len]
            n_valid = self.max_seq_len

        return {
            'features': torch.tensor(x, dtype=torch.float32),
            'labels': torch.tensor(y, dtype=torch.long),
            'length': n_valid,
            'night_id': record['night_id'],
        }
# ============================================================================
# 数据加载和预处理
# ============================================================================
def _impute_non_finite(features, feature_names):
    """Replace NaN/+-Inf entries of every feature column, in place, by that column's finite median."""
    for col_idx, name in enumerate(feature_names):
        column = features[:, col_idx]  # basic slice -> view, so writes land in `features`
        bad = ~np.isfinite(column)
        n_bad = int(bad.sum())
        if n_bad == 0:
            continue
        # Take the median over finite values only: np.nanmedian would treat
        # +-Inf as ordinary data and could return a skewed or infinite fill value.
        finite_values = column[~bad]
        fill_value = float(np.median(finite_values)) if finite_values.size else 0.0
        column[bad] = fill_value
        print(f" Fixed {n_bad} NaN/Inf values in {name}")
    return features


def load_and_preprocess_data(max_seq_len=1200, cache_path='/app/sleep_data.npz',
                             subset_size=500000):
    """
    Load the sleep dataset, preprocess it and build train/val/test datasets.

    Data source: ``abmallick/heart-breath-sleep-stage-dataset`` on the
    HuggingFace Hub, or the ``.npz`` file at ``cache_path`` when it exists.

    Feature engineering:
      1. Direct features: heart_rate, respiratory_rate
      2. HRV feature: hr_rmssd_5
      3. Activity proxy: hr_slope_3 (heart-rate slope) stands in for body
         movement, because the dataset has no accelerometer channel
      4. Per-patient z-score normalisation (SleepPPG-Net), then clipping
         to [-5, 5]

    Args:
        max_seq_len: Number of epochs every night is padded/truncated to.
        cache_path: Location of the ``.npz`` cache. It is read when present;
            otherwise it is written after download. A failed write only
            produces a warning.
        subset_size: Maximum number of leading dataset rows to use when
            downloading. An existing cache is used as-is.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset, class_weights,
        feature_columns, n_classes)``. ``class_weights`` is a NumPy array of
        inverse-frequency weights computed on the training split, rescaled so
        the smallest is 1 and clipped to [1, 5].
    """
    feature_columns = ['hr_rmssd_5', 'heart_rate', 'respiratory_rate', 'hr_slope_3']
    banner = "=" * 60
    print(banner)
    print("Loading dataset...")
    print(banner)

    if os.path.exists(cache_path):
        print(f"Loading from cache: {cache_path}")
        # Context manager closes the underlying zip file once arrays are read.
        with np.load(cache_path) as data:
            hr = data['heart_rate']
            rr = data['respiratory_rate']
            hrv = data['hr_rmssd_5']
            slope = data['hr_slope_3']
            night_ids = data['night_id']
            raw_stages = data['sleep_stage']
    else:
        print("Loading from HuggingFace Hub (this may take a while)...")
        from datasets import load_dataset
        ds = load_dataset("abmallick/heart-breath-sleep-stage-dataset", split="train")
        # Use only the leading rows to keep training manageable.
        ds_sub = ds.select(range(min(subset_size, len(ds))))
        hr = np.array(ds_sub['heart_rate'], dtype=np.float32)
        rr = np.array(ds_sub['respiratory_rate'], dtype=np.float32)
        hrv = np.array(ds_sub['hr_rmssd_5'], dtype=np.float32)
        slope = np.array(ds_sub['hr_slope_3'], dtype=np.float32)
        night_ids = np.array(ds_sub['night_id'], dtype=np.int32)
        raw_stages = np.array(ds_sub['sleep_stage'], dtype=np.int32)
        # Caching is an optimisation: never lose the download over a write error.
        try:
            cache_dir = os.path.dirname(cache_path)
            if cache_dir:
                os.makedirs(cache_dir, exist_ok=True)
            np.savez(cache_path, heart_rate=hr, respiratory_rate=rr,
                     hr_rmssd_5=hrv, hr_slope_3=slope,
                     night_id=night_ids, sleep_stage=raw_stages)
            print(f"Cached to {cache_path}")
        except OSError as exc:
            print(f"Warning: could not write cache {cache_path}: {exc}")

    print(f"Total records: {len(hr):,}")
    print(f"Night IDs: {len(np.unique(night_ids))}")

    # ---- Stage label mapping ----
    # Raw codes handled here: 0..3 are kept, 4 (if present) is treated as REM,
    # 5 and 9 are folded into Wake (0).
    # NOTE(review): several public sleep databases encode REM as 5 and
    # unscored epochs as 9; confirm against the dataset card that folding 5
    # into Wake is really intended.
    unique_stages, raw_counts = np.unique(raw_stages, return_counts=True)
    print(f"Unique raw stages: {unique_stages}")
    print("Raw stage distribution:")
    for s, c in zip(unique_stages, raw_counts):
        print(f" Stage {s}: {c:,} ({c/len(raw_stages)*100:.1f}%)")

    labels = raw_stages.copy()
    labels[raw_stages == 5] = 0  # 5 -> Wake
    labels[raw_stages == 9] = 0  # 9 -> Wake (movement/unknown)

    # The number of classes depends on whether a REM code (4) occurs at all.
    has_rem = 4 in unique_stages
    if has_rem:
        n_classes = 5
        print("5-class classification: Wake, N1, N2, N3, REM")
    else:
        n_classes = 4
        print("4-class classification: Wake, N1, N2, N3 (no REM in dataset)")

    print("Mapped stage distribution:")
    for s, c in zip(*np.unique(labels, return_counts=True)):
        name = STAGE_NAMES[s] if s < len(STAGE_NAMES) else f"Stage{s}"
        print(f" {name} ({s}): {c:,} ({c/len(labels)*100:.1f}%)")

    # ---- Feature matrix: [HRV, HR, RR, movement proxy] ----
    features = np.stack([hrv, hr, rr, slope], axis=-1).astype(np.float32)
    print(f"\nFeatures shape: {features.shape}")
    print(f"Selected features: {feature_columns}")

    # Missing / non-finite values are imputed before normalisation.
    _impute_non_finite(features, feature_columns)

    # Per-patient z-score normalisation (SleepPPG-Net), then clip outliers.
    print("\nApplying per-patient Z-score normalization...")
    features = SleepDataProcessor.per_patient_normalize(features, night_ids)
    features = np.clip(features, -5.0, 5.0)

    # ---- Split by night (80/10/10) so no night is shared between splits ----
    unique_nights = np.unique(night_ids)
    rng = np.random.RandomState(42)
    rng.shuffle(unique_nights)
    n_total = len(unique_nights)
    n_train = int(n_total * 0.8)
    n_val = int(n_total * 0.1)
    train_nights = unique_nights[:n_train]
    val_nights = unique_nights[n_train:n_train + n_val]
    test_nights = unique_nights[n_train + n_val:]
    print("\nData split:")
    print(f" Train: {len(train_nights)} nights")
    print(f" Val: {len(val_nights)} nights")
    print(f" Test: {len(test_nights)} nights")

    train_mask = np.isin(night_ids, train_nights)
    val_mask = np.isin(night_ids, val_nights)
    test_mask = np.isin(night_ids, test_nights)

    # Only the training split is augmented.
    train_dataset = SleepNightDataset(
        features[train_mask], labels[train_mask], night_ids[train_mask],
        max_seq_len=max_seq_len, augment=True
    )
    val_dataset = SleepNightDataset(
        features[val_mask], labels[val_mask], night_ids[val_mask],
        max_seq_len=max_seq_len, augment=False
    )
    test_dataset = SleepNightDataset(
        features[test_mask], labels[test_mask], night_ids[test_mask],
        max_seq_len=max_seq_len, augment=False
    )

    # ---- Class weights from the training split (inverse frequency, Mamba-sleep) ----
    class_counts = np.bincount(labels[train_mask], minlength=n_classes)
    total = class_counts.sum()
    # The epsilon keeps an absent class from dividing by zero; its weight is
    # then capped by the clip below.
    class_weights = total / (n_classes * class_counts + 1e-8)
    class_weights = np.clip(class_weights / class_weights.min(), 1.0, 5.0)
    print(f"\nClass weights (inverse frequency): {class_weights.tolist()}")

    return train_dataset, val_dataset, test_dataset, class_weights, feature_columns, n_classes
# ============================================================================
# 训练循环
# ============================================================================
class EarlyStopping:
    """
    Early-stopping tracker for a score that should increase (wav2sleep uses patience=5).

    Call the instance once per epoch with the validation score and the model.
    An epoch counts as an improvement when the score exceeds the best one seen
    so far by more than ``min_delta``; the model weights of that epoch are then
    snapshotted on the CPU in ``best_model_state``. After ``patience``
    consecutive non-improving epochs ``early_stop`` becomes True.
    """

    def __init__(self, patience=10, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best_score = None
        self.best_model_state = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_score, model):
        improved = (
            self.best_score is None
            or val_score > self.best_score + self.min_delta
        )
        if not improved:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return

        self.best_score = val_score
        self.counter = 0
        # Detached CPU copies, so later training steps cannot alter the snapshot.
        snapshot = {}
        for key, tensor in model.state_dict().items():
            snapshot[key] = tensor.cpu().clone()
        self.best_model_state = snapshot
def evaluate(model, dataloader, device, n_classes=5):
    """
    Run the model over a dataloader and compute classification metrics.

    Batches are dicts with 'features', 'labels' and 'length' entries. Label
    value -1 marks padding: it is ignored by the cross-entropy loss and is
    excluded from the collected predictions.

    Returns:
        dict with 'loss' (mean of the per-batch losses), 'accuracy',
        'f1_macro', 'f1_weighted', 'kappa', and the flat NumPy arrays
        'preds' and 'labels' over all valid epochs.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(ignore_index=-1)
    collected_preds = []
    collected_labels = []
    loss_sum = 0
    batch_count = 0

    with torch.no_grad():
        for batch in dataloader:
            x = batch['features'].to(device)
            y = batch['labels'].to(device)

            logits = model(x)  # (batch, seq_len, n_classes)

            # Padding positions are skipped via ignore_index.
            batch_loss = criterion(logits.reshape(-1, n_classes), y.reshape(-1))
            loss_sum += batch_loss.item()
            batch_count += 1

            y_hat = torch.argmax(logits, dim=-1)  # (batch, seq_len)
            seq_len = logits.size(1)
            for row, n_valid in enumerate(batch['length']):
                # Keep the valid prefix of each night, then drop any
                # remaining padding labels inside it.
                stop = min(n_valid, seq_len)
                row_preds = y_hat[row, :stop].cpu().numpy()
                row_labels = y[row, :stop].cpu().numpy()
                keep = row_labels >= 0
                collected_preds.extend(row_preds[keep].tolist())
                collected_labels.extend(row_labels[keep].tolist())

    all_preds = np.array(collected_preds)
    all_labels = np.array(collected_labels)

    return {
        'loss': loss_sum / max(batch_count, 1),
        'accuracy': accuracy_score(all_labels, all_preds),
        'f1_macro': f1_score(all_labels, all_preds, average='macro', zero_division=0),
        'f1_weighted': f1_score(all_labels, all_preds, average='weighted', zero_division=0),
        'kappa': cohen_kappa_score(all_labels, all_preds),
        'preds': all_preds,
        'labels': all_labels,
    }
def _scalar_metrics(metrics):
    """Return a copy of an `evaluate` result without the bulky 'preds'/'labels' arrays."""
    return {k: v for k, v in metrics.items() if k not in ('preds', 'labels')}


def _print_confusion_matrix(cm, names):
    """Print a confusion matrix as a fixed-width table with one row/column per name."""
    print("Confusion Matrix:")
    print(f"{'':>8}" + ''.join(f"{name:>8}" for name in names))
    for name, row in zip(names, cm):
        print(f"{name:>8}" + ''.join(f"{count:>8}" for count in row))


def train(
    model_config='base',
    n_features=4,
    n_classes=5,
    max_seq_len=1200,
    batch_size=16,
    lr=1e-3,
    weight_decay=1e-2,
    max_epochs=100,
    patience=10,
    warmup_epochs=5,
    device='auto',
    save_dir='./checkpoints',
    project_name='sleep-staging',
    run_name=None,
):
    """
    Full training pipeline: load data, train, early-stop, evaluate on the test split.

    Hyper-parameter provenance:
      - lr=1e-3, weight_decay=1e-2: wav2sleep Section 4.2
      - batch_size=16: wav2sleep Section 4.2
      - patience=10: relaxed from wav2sleep's 5
      - warm-up: wav2sleep uses a 2000-step linear warm-up; here OneCycleLR
        warms up over the first 10% of all steps

    Args:
        model_config: Model size key passed to ``create_model``.
        n_features: Number of input features per epoch.
        n_classes: Ignored in practice: overwritten by the class count
            detected in the data.
        max_seq_len: Epochs per night after padding/truncation.
        batch_size: Nights per batch.
        lr: Peak learning rate of the one-cycle schedule.
        weight_decay: AdamW weight decay.
        max_epochs: Upper bound on training epochs.
        patience: Early-stopping patience in epochs (on validation kappa).
        warmup_epochs: Currently unused; the warm-up fraction is fixed at 10%.
        device: 'auto', 'cuda' or 'cpu'.
        save_dir: Directory for config.json, results.json and checkpoints.
        project_name: Trackio project name.
        run_name: Run name; a default is derived from the configuration.

    Returns:
        Tuple ``(model, test_metrics, history)``; the model carries the best
        validation weights, ``history`` is a list of per-epoch dicts.

    Raises:
        ValueError: If the schedule would contain no optimisation step
            (e.g. fewer training nights than ``batch_size`` with drop_last).
    """
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\nDevice: {device}")

    # Reproducibility.
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

    # ---- Data ----
    train_dataset, val_dataset, test_dataset, class_weights, feature_names, n_classes_data = \
        load_and_preprocess_data(max_seq_len=max_seq_len)
    n_classes = n_classes_data  # the data decides the number of classes

    pin = (device == 'cuda')
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=0, pin_memory=pin,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=0, pin_memory=pin,
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        num_workers=0, pin_memory=pin,
    )

    # ---- Model, loss, optimiser, schedule ----
    model = create_model(model_config, n_features=n_features, n_classes=n_classes)
    model = model.to(device)

    loss_fn = WeightedFocalLoss(
        class_weights=class_weights.tolist(),
        gamma=2.0,
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
        betas=(0.9, 0.999),
    )

    total_steps = len(train_loader) * max_epochs
    if total_steps <= 0:
        raise ValueError(
            f"No training steps: {len(train_dataset)} training nights with "
            f"batch_size={batch_size} (drop_last=True) and max_epochs={max_epochs}"
        )
    scheduler = OneCycleLR(
        optimizer,
        max_lr=lr,
        total_steps=total_steps,
        pct_start=0.1,          # 10% warm-up
        anneal_strategy='cos',
        div_factor=10,          # initial lr = max_lr / 10
        final_div_factor=100,   # final lr = initial lr / 100 = max_lr / 1000
    )

    early_stopping = EarlyStopping(patience=patience)

    # ---- Optional Trackio monitoring ----
    # Logging is enabled only when trackio.init() has actually been called;
    # without TRACKIO_SPACE_ID no run exists and nothing must be logged.
    tracking = False
    if HAS_TRACKIO:
        space_id = os.environ.get('TRACKIO_SPACE_ID', None)
        if space_id:
            # NOTE(review): confirm that trackio.init accepts a `run=` keyword;
            # wandb-style APIs usually name this argument `name=`.
            trackio.init(
                project=project_name,
                run=run_name or f"sleepnet_{model_config}_lr{lr}",
                space_id=space_id,
            )
            tracking = True

    os.makedirs(save_dir, exist_ok=True)
    if run_name is None:
        run_name = f"sleepnet_{model_config}_lr{lr}_bs{batch_size}"

    n_parameters = model.count_parameters()
    config = {
        'model_config': model_config,
        'n_features': n_features,
        'n_classes': n_classes,
        'feature_names': feature_names,
        'max_seq_len': max_seq_len,
        'batch_size': batch_size,
        'lr': lr,
        'weight_decay': weight_decay,
        'max_epochs': max_epochs,
        'patience': patience,
        'class_weights': class_weights.tolist(),
        'n_parameters': n_parameters,
        'device': device,
    }
    with open(os.path.join(save_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=2)

    print(f"\n{'='*60}")
    print(f"Training: {run_name}")
    print(f"{'='*60}")
    print(f" Model: SleepStageNet-{model_config} ({n_parameters:,} params)")
    print(f" Features: {feature_names}")
    print(f" Train: {len(train_dataset)} nights | Val: {len(val_dataset)} | Test: {len(test_dataset)}")
    print(f" Batch size: {batch_size} | LR: {lr} | Weight decay: {weight_decay}")
    print(f" Max epochs: {max_epochs} | Early stopping patience: {patience}")
    print(f"{'='*60}\n")

    # ===== Training loop =====
    best_kappa = -1
    history = []
    for epoch in range(1, max_epochs + 1):
        model.train()
        train_loss = 0
        n_batches = 0
        epoch_start = time.time()

        for batch_idx, batch in enumerate(train_loader):
            features = batch['features'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            logits = model(features)

            # Padding positions carry label -1 and are excluded from the loss.
            flat_labels = labels.reshape(-1)
            valid_mask = flat_labels >= 0
            if not valid_mask.any():
                continue
            loss = loss_fn(logits.reshape(-1, n_classes)[valid_mask], flat_labels[valid_mask])

            loss.backward()
            # Gradient clipping against exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            train_loss += loss.item()
            n_batches += 1

            if tracking and (batch_idx + 1) % 10 == 0:
                trackio.log({
                    'train/loss': loss.item(),
                    'train/lr': scheduler.get_last_lr()[0],
                })

        avg_train_loss = train_loss / max(n_batches, 1)
        epoch_time = time.time() - epoch_start

        # --- Validation ---
        val_metrics = evaluate(model, val_loader, device, n_classes)
        current_lr = scheduler.get_last_lr()[0]

        history.append({
            'epoch': epoch,
            'train_loss': avg_train_loss,
            'val_loss': val_metrics['loss'],
            'val_accuracy': val_metrics['accuracy'],
            'val_f1_macro': val_metrics['f1_macro'],
            'val_kappa': val_metrics['kappa'],
            'lr': current_lr,
            'time': epoch_time,
        })

        print(f"Epoch {epoch:3d}/{max_epochs} | "
              f"Train Loss: {avg_train_loss:.4f} | "
              f"Val Loss: {val_metrics['loss']:.4f} | "
              f"Val Acc: {val_metrics['accuracy']:.4f} | "
              f"Val F1: {val_metrics['f1_macro']:.4f} | "
              f"Val κ: {val_metrics['kappa']:.4f} | "
              f"LR: {current_lr:.2e} | "
              f"Time: {epoch_time:.1f}s")

        if tracking:
            trackio.log({
                'epoch': epoch,
                'train/epoch_loss': avg_train_loss,
                'val/loss': val_metrics['loss'],
                'val/accuracy': val_metrics['accuracy'],
                'val/f1_macro': val_metrics['f1_macro'],
                'val/f1_weighted': val_metrics['f1_weighted'],
                'val/kappa': val_metrics['kappa'],
            })

        # Early stopping is the single source of truth for "best epoch":
        # its counter is reset to 0 exactly when it stored a new snapshot, so
        # best_model.pt, best_kappa and the weights evaluated on the test set
        # always refer to the same epoch.
        early_stopping(val_metrics['kappa'], model)
        if early_stopping.counter == 0:
            best_kappa = early_stopping.best_score
            print(f" ★ New best κ: {best_kappa:.4f}")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Scalars only: the per-epoch prediction arrays would bloat the file.
                'val_metrics': _scalar_metrics(val_metrics),
                'config': config,
            }, os.path.join(save_dir, 'best_model.pt'))

        if early_stopping.early_stop:
            print(f"\n⚠ Early stopping at epoch {epoch} (patience={patience})")
            break

    # ===== Test evaluation =====
    print(f"\n{'='*60}")
    print("Final Evaluation on Test Set")
    print(f"{'='*60}")

    # Restore the best validation weights before testing.
    if early_stopping.best_model_state is not None:
        model.load_state_dict(early_stopping.best_model_state)
    model = model.to(device)

    test_metrics = evaluate(model, test_loader, device, n_classes)
    print(f"\nTest Results:")
    print(f" Accuracy: {test_metrics['accuracy']:.4f}")
    print(f" F1 (macro): {test_metrics['f1_macro']:.4f}")
    print(f" F1 (weight): {test_metrics['f1_weighted']:.4f}")
    print(f" Cohen's κ: {test_metrics['kappa']:.4f}")

    # Report only the classes that occur in the labels or the predictions.
    used_labels = sorted(set(test_metrics['labels'].tolist()) | set(test_metrics['preds'].tolist()))
    used_names = [STAGE_NAMES[i] if i < len(STAGE_NAMES) else f"Class{i}" for i in used_labels]
    print(f"\nClassification Report:")
    print(classification_report(
        test_metrics['labels'], test_metrics['preds'],
        labels=used_labels, target_names=used_names, zero_division=0,
    ))

    cm = confusion_matrix(test_metrics['labels'], test_metrics['preds'], labels=used_labels)
    _print_confusion_matrix(cm, used_names)

    if tracking:
        if test_metrics['kappa'] >= 0.60:
            trackio.alert(
                "Training Complete - Good Performance",
                f"κ={test_metrics['kappa']:.4f}, F1={test_metrics['f1_macro']:.4f}, "
                f"Acc={test_metrics['accuracy']:.4f}. Model is usable for deployment.",
                level="info"
            )
        elif test_metrics['kappa'] >= 0.40:
            trackio.alert(
                "Training Complete - Moderate Performance",
                f"κ={test_metrics['kappa']:.4f}. Consider: (1) more data, "
                f"(2) larger model, (3) additional features.",
                level="warn"
            )
        else:
            trackio.alert(
                "Training Complete - Low Performance",
                f"κ={test_metrics['kappa']:.4f}. Needs investigation: "
                f"check data quality, feature engineering, or model capacity.",
                level="error"
            )

    # ---- Persist results and the final model ----
    results = {
        'test_accuracy': test_metrics['accuracy'],
        'test_f1_macro': test_metrics['f1_macro'],
        'test_f1_weighted': test_metrics['f1_weighted'],
        'test_kappa': test_metrics['kappa'],
        'best_val_kappa': best_kappa,
        'total_epochs': len(history),
        'config': config,
        'history': history,
    }
    with open(os.path.join(save_dir, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2)

    torch.save({
        'model_state_dict': model.state_dict(),
        'config': config,
        'test_metrics': _scalar_metrics(test_metrics),
        'feature_names': feature_names,
        'stage_names': STAGE_NAMES,
    }, os.path.join(save_dir, 'final_model.pt'))

    print(f"\n✅ Training complete! Models saved to {save_dir}")
    print(f" Best validation κ: {best_kappa:.4f}")
    print(f" Test κ: {test_metrics['kappa']:.4f}")
    return model, test_metrics, history
# ============================================================================
# 推理接口
# ============================================================================
def predict_sleep_stages(
    model_path: str,
    hrv_sequence: np.ndarray,
    hr_sequence: np.ndarray,
    rr_sequence: np.ndarray,
    movement_sequence: np.ndarray,
    device: str = 'auto',
) -> dict:
    """
    Predict sleep stages for one recording with a trained checkpoint.

    The four sequences must have the same length, one value per 30-second
    epoch, and are stacked in the training feature order
    [HRV, HR, RR, movement proxy]. NaN/Inf values are replaced by the finite
    median of their sequence, then the recording is z-scored with its own
    statistics and clipped to [-5, 5].

    Args:
        model_path: Path to a checkpoint holding 'config' and 'model_state_dict'.
        hrv_sequence: HRV (RMSSD) values.
        hr_sequence: Heart rate values (bpm).
        rr_sequence: Respiratory rate values (breaths/min).
        movement_sequence: Movement/activity values. Note that the training
            script in this file feeds hr_slope_3 (heart-rate slope) into this
            slot, so inference inputs should be comparable.
        device: 'auto', 'cuda' or 'cpu'.

    Returns:
        dict with
          'stages': predicted class id per epoch,
          'stage_names': the corresponding stage names,
          'probabilities': softmax probabilities per epoch and class,
          'summary': total time, sleep efficiency and minutes/percent per stage.

    Raises:
        ValueError: If the input sequences are empty.
    """
    # Validate and assemble the input first, so bad input fails fast,
    # before any checkpoint is read.
    features = np.stack([
        hrv_sequence, hr_sequence, rr_sequence, movement_sequence
    ], axis=-1).astype(np.float32)
    if features.size == 0:
        raise ValueError("predict_sleep_stages requires at least one epoch of input data")

    # Impute non-finite values the same way training does; a single NaN
    # would otherwise turn the whole normalised column into NaN.
    bad = ~np.isfinite(features)
    if bad.any():
        for col in range(features.shape[-1]):
            column = features[..., col]  # view into `features`
            col_bad = bad[..., col]
            if not col_bad.any():
                continue
            finite_values = column[~col_bad]
            column[col_bad] = float(np.median(finite_values)) if finite_values.size else 0.0

    # Z-score with this recording's own statistics; constant columns keep std 1.
    mean = features.mean(axis=0)
    std = features.std(axis=0)
    std[std < 1e-8] = 1.0
    features = (features - mean) / std
    features = np.clip(features, -5.0, 5.0)

    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # SECURITY: weights_only=False unpickles arbitrary objects and can execute
    # code embedded in the file. Only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    config = checkpoint['config']
    model = create_model(
        config['model_config'],
        n_features=config['n_features'],
        n_classes=config['n_classes'],
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    # Add a batch dimension of size 1.
    x = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(x)
        probs = F.softmax(logits, dim=-1)
        predictions = torch.argmax(logits, dim=-1)
    stages = predictions[0].cpu().numpy()
    probabilities = probs[0].cpu().numpy()

    # ---- Summary statistics ----
    stage_counts = Counter(stages.tolist())
    total_epochs = len(stages)
    summary = {
        'total_time_hours': total_epochs * 30 / 3600,
        # Every epoch not classified as Wake (class 0) counts as sleep.
        'sleep_efficiency': (total_epochs - stage_counts.get(0, 0)) / total_epochs * 100,
    }
    for i, name in enumerate(STAGE_NAMES):
        count = stage_counts.get(i, 0)
        summary[f'{name}_minutes'] = count * 0.5  # one epoch = 30 s = 0.5 min
        summary[f'{name}_percent'] = count / total_epochs * 100

    return {
        'stages': stages,
        'stage_names': [STAGE_NAMES[s] for s in stages],
        'probabilities': probabilities,
        'summary': summary,
    }
# ============================================================================
# 主入口
# ============================================================================
if __name__ == '__main__':
    # Command-line entry point: parse the options and launch one training run.
    import argparse

    cli = argparse.ArgumentParser(description='Sleep Stage Classification Training')
    option_table = (
        ('--model', dict(type=str, default='base', choices=['small', 'base', 'large'])),
        ('--batch_size', dict(type=int, default=16)),
        ('--lr', dict(type=float, default=1e-3)),
        ('--weight_decay', dict(type=float, default=1e-2)),
        ('--max_epochs', dict(type=int, default=100)),
        ('--patience', dict(type=int, default=10)),
        ('--max_seq_len', dict(type=int, default=1200)),
        ('--save_dir', dict(type=str, default='./checkpoints')),
        ('--device', dict(type=str, default='auto')),
        ('--quick_test', dict(action='store_true', help='Quick test with 3 epochs')),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    args = cli.parse_args()

    # A quick test caps both the epoch budget and the patience at 3.
    if args.quick_test:
        args.max_epochs = args.patience = 3

    model, metrics, history = train(
        model_config=args.model,
        batch_size=args.batch_size,
        lr=args.lr,
        weight_decay=args.weight_decay,
        max_epochs=args.max_epochs,
        patience=args.patience,
        max_seq_len=args.max_seq_len,
        device=args.device,
        save_dir=args.save_dir,
    )