cledouxluma committed on
Commit
bd646ad
·
verified ·
1 Parent(s): b92a969

Upload utils/io.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. utils/io.py +43 -0
utils/io.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """I/O utilities for checkpoint management."""
2
+
3
+ import os
4
+ import torch
5
+ import torch.nn as nn
6
+ from typing import Dict, Optional
7
+
8
+
9
def save_checkpoint(model: nn.Module, optimizer, scheduler, epoch: int,
                    losses: Dict, output_path: str, config: Optional[Dict] = None):
    """Serialize a training checkpoint to *output_path*.

    Args:
        model: Model whose ``state_dict`` is written under ``'model_state_dict'``.
        optimizer: Optimizer whose state is written under ``'optimizer_state_dict'``.
        scheduler: LR scheduler, or ``None``; when ``None`` no
            ``'scheduler_state_dict'`` entry is written (the original
            implementation crashed in this case).
        epoch: Epoch number recorded under ``'epoch'``.
        losses: Loss values/history recorded under ``'losses'``.
        output_path: Destination file; missing parent directories are created.
        config: Optional run configuration stored under ``'config'``.

    Returns:
        The path the checkpoint was written to (``output_path``).
    """
    # ``or '.'`` covers bare filenames, where dirname() returns ''.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'losses': losses,
    }
    # Tolerate runs without an LR scheduler, mirroring load_checkpoint's
    # optional handling (unconditional .state_dict() raised on None before).
    if scheduler is not None:
        state['scheduler_state_dict'] = scheduler.state_dict()
    # Compare against None (not truthiness) so an explicitly-passed empty
    # config dict is still persisted.
    if config is not None:
        state['config'] = config

    torch.save(state, output_path)
    return output_path
26
+
27
+
28
def load_checkpoint(model: nn.Module, checkpoint_path: str,
                    optimizer=None, scheduler=None,
                    strict: bool = True) -> Dict:
    """Restore model weights from *checkpoint_path*, optionally with
    optimizer and scheduler state.

    Args:
        model: Model to load weights into.
        checkpoint_path: File produced by ``save_checkpoint`` or a bare
            ``state_dict`` saved with ``torch.save``.
        optimizer: If given, its state is restored when the checkpoint
            contains ``'optimizer_state_dict'``.
        scheduler: If given, its state is restored when the checkpoint
            contains ``'scheduler_state_dict'``.
        strict: Passed through to ``Module.load_state_dict``.

    Returns:
        The full checkpoint dict loaded from disk.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Accept both full training checkpoints and bare state dicts: fall back
    # to the whole object when the wrapper key is absent.
    weights = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    model.load_state_dict(weights, strict=strict)

    for target, key in ((optimizer, 'optimizer_state_dict'),
                        (scheduler, 'scheduler_state_dict')):
        if target and key in checkpoint:
            target.load_state_dict(checkpoint[key])

    return checkpoint