| """I/O utilities for checkpoint management.""" |
|
|
| import os |
| import torch |
| import torch.nn as nn |
| from typing import Dict, Optional |
|
|
|
|
def save_checkpoint(model: nn.Module, optimizer, scheduler, epoch: int,
                    losses: Dict, output_path: str, config: Optional[Dict] = None):
    """Save a training checkpoint to *output_path*.

    Args:
        model: Model whose ``state_dict`` is persisted.
        optimizer: Optimizer whose state is persisted, or ``None`` to skip.
        scheduler: LR scheduler whose state is persisted, or ``None`` to skip.
        epoch: Epoch number recorded in the checkpoint.
        losses: Loss history / metrics to store alongside the weights.
        output_path: Destination file; missing parent directories are created.
        config: Optional run configuration to embed in the checkpoint.

    Returns:
        The ``output_path`` the checkpoint was written to.
    """
    # dirname is '' for bare filenames; fall back to '.' so makedirs succeeds.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'losses': losses,
    }
    # Optimizer/scheduler may legitimately be absent (e.g. weights-only
    # export). Omit the keys entirely rather than crashing on `.state_dict()`
    # of None; load_checkpoint already tolerates missing keys.
    if optimizer is not None:
        state['optimizer_state_dict'] = optimizer.state_dict()
    if scheduler is not None:
        state['scheduler_state_dict'] = scheduler.state_dict()
    # `is not None` (not truthiness) so an empty-but-present config is kept.
    if config is not None:
        state['config'] = config

    torch.save(state, output_path)
    return output_path
|
|
|
|
def load_checkpoint(model: nn.Module, checkpoint_path: str,
                    optimizer=None, scheduler=None,
                    strict: bool = True) -> Dict:
    """Load a checkpoint into *model* and optionally restore training state.

    Args:
        model: Model to receive the checkpoint's weights.
        checkpoint_path: Path to a file produced by ``torch.save``. Either a
            full training checkpoint (with ``model_state_dict`` etc.) or a
            bare state dict is accepted.
        optimizer: If given, its state is restored when present in the file.
        scheduler: If given, its state is restored when present in the file.
        strict: Passed to ``model.load_state_dict``; when ``True``, key
            mismatches raise.

    Returns:
        The full checkpoint dict as loaded from disk.
    """
    # map_location='cpu' lets GPU-saved checkpoints load on any host; the
    # caller can move the model afterwards.
    # NOTE(review): torch.load unpickles arbitrary objects on older torch
    # versions — only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Support both wrapped checkpoints and bare state dicts.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict, strict=strict)

    # Explicit `is not None`: truthiness of arbitrary optimizer/scheduler
    # objects is not guaranteed, and a falsy-but-valid one would otherwise
    # be silently skipped.
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    return checkpoint
|
|