"""I/O utilities for checkpoint management."""
import os
import torch
import torch.nn as nn
from typing import Dict, Optional
def save_checkpoint(model: nn.Module, optimizer, scheduler, epoch: int,
                    losses: Dict, output_path: str,
                    config: Optional[Dict] = None) -> str:
    """Save a training checkpoint to *output_path*.

    Args:
        model: Model whose ``state_dict`` is stored.
        optimizer: Optimizer whose state is stored; pass ``None`` to skip
            (e.g. inference-only snapshots).
        scheduler: LR scheduler whose state is stored; pass ``None`` to skip.
        epoch: Current epoch number.
        losses: Loss history / metrics dict, stored verbatim.
        output_path: Destination file; parent directories are created.
        config: Optional run configuration stored under ``'config'``.

    Returns:
        The path the checkpoint was written to.
    """
    # '.' fallback handles bare filenames with no directory component.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'losses': losses,
    }
    # Tolerate None optimizer/scheduler; the original crashed on
    # `.state_dict()` here. Existing callers (both non-None) see the
    # same keys as before.
    if optimizer is not None:
        state['optimizer_state_dict'] = optimizer.state_dict()
    if scheduler is not None:
        state['scheduler_state_dict'] = scheduler.state_dict()
    # BUG FIX: `if config:` dropped an empty-dict config — test identity,
    # not truthiness, so an explicitly-passed {} is still persisted.
    if config is not None:
        state['config'] = config
    torch.save(state, output_path)
    return output_path
def load_checkpoint(model: nn.Module, checkpoint_path: str,
                    optimizer=None, scheduler=None,
                    strict: bool = True) -> Dict:
    """Load a checkpoint and restore model (and optional optimizer/scheduler) state.

    Args:
        model: Model to load weights into (in place).
        checkpoint_path: Path to a file written by ``save_checkpoint``, or
            a bare ``state_dict`` saved directly with ``torch.save``.
        optimizer: If given, its state is restored when present in the file.
        scheduler: If given, its state is restored when present in the file.
        strict: Forwarded to ``model.load_state_dict``.

    Returns:
        The full checkpoint dict (epoch, losses, etc.), or the raw
        ``state_dict`` when the file contains one.
    """
    # NOTE(security): checkpoints are unpickled — only load trusted files.
    # weights_only=False pins the pre-torch-2.6 default so checkpoints whose
    # 'config' holds arbitrary objects keep loading after the default flipped.
    checkpoint = torch.load(checkpoint_path, map_location='cpu',
                            weights_only=False)
    # Accept either a full checkpoint dict or a raw state_dict.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict, strict=strict)
    # Identity checks (not truthiness) — mirrors save_checkpoint's style.
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return checkpoint