"""I/O utilities for checkpoint management.""" import os import torch import torch.nn as nn from typing import Dict, Optional def save_checkpoint(model: nn.Module, optimizer, scheduler, epoch: int, losses: Dict, output_path: str, config: Dict = None): """Save training checkpoint.""" os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True) state = { 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'losses': losses, } if config: state['config'] = config torch.save(state, output_path) return output_path def load_checkpoint(model: nn.Module, checkpoint_path: str, optimizer=None, scheduler=None, strict: bool = True) -> Dict: """Load checkpoint and optionally restore optimizer/scheduler state.""" checkpoint = torch.load(checkpoint_path, map_location='cpu') state_dict = checkpoint.get('model_state_dict', checkpoint) model.load_state_dict(state_dict, strict=strict) if optimizer and 'optimizer_state_dict' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer_state_dict']) if scheduler and 'scheduler_state_dict' in checkpoint: scheduler.load_state_dict(checkpoint['scheduler_state_dict']) return checkpoint