"""I/O utilities for checkpoint management."""

import os
import torch
import torch.nn as nn
from typing import Dict, Optional


def save_checkpoint(model: nn.Module, optimizer, scheduler, epoch: int,
                    losses: Dict, output_path: str,
                    config: Optional[Dict] = None) -> str:
    """Save a training checkpoint and return the path it was written to."""
    # Create the target directory if needed; fall back to '.' when the
    # path has no directory component.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'losses': losses,
    }
    # Check against None explicitly so an empty config dict is still saved.
    if config is not None:
        state['config'] = config

    torch.save(state, output_path)
    return output_path


def load_checkpoint(model: nn.Module, checkpoint_path: str,
                    optimizer=None, scheduler=None,
                    strict: bool = True) -> Dict:
    """Load a checkpoint and optionally restore optimizer/scheduler state."""
    # weights_only=False because the checkpoint holds arbitrary Python
    # objects (losses, config), which recent PyTorch refuses to unpickle
    # by default; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu',
                            weights_only=False)

    # Accept both full training checkpoints and bare model state dicts.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict, strict=strict)

    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    return checkpoint
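

if __name__ == '__main__':
    # Minimal roundtrip sketch (illustrative only, not part of the
    # original module): save a checkpoint for a toy model, then restore
    # it into a fresh instance. The toy model, optimizer, and scheduler
    # choices here are assumptions for the demo.
    import tempfile

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, 'ckpt.pt')
        save_checkpoint(model, optimizer, scheduler, epoch=0,
                        losses={'train': 1.0}, output_path=path)

        restored = nn.Linear(4, 2)
        state = load_checkpoint(restored, path, optimizer=optimizer,
                                scheduler=scheduler)
        print('resumed from epoch', state['epoch'])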