| |
| |
| |
| |
| |
|
|
| from dataclasses import dataclass, field |
| from datetime import timedelta |
| from io import BytesIO |
| from typing import Any, Dict, List |
|
|
| import torch |
| from torch.distributed.checkpoint.stateful import Stateful |
|
|
|
|
| @dataclass |
| class TrainState(Stateful): |
| step: int = 0 |
| skipped_step: int = 0 |
| token: int = 0 |
| elapsed: timedelta = timedelta(0) |
| global_avg_losses: List[float] = field(default_factory=list) |
| global_max_losses: List[float] = field(default_factory=list) |
| log_steps: List[int] = field(default_factory=list) |
|
|
| def state_dict(self) -> Dict[str, Any]: |
| |
| |
| global_avg_losses_bytes = BytesIO() |
| torch.save(self.global_avg_losses, global_avg_losses_bytes) |
| global_max_losses_bytes = BytesIO() |
| torch.save(self.global_max_losses, global_max_losses_bytes) |
| log_steps_bytes = BytesIO() |
| torch.save(self.log_steps, log_steps_bytes) |
| return { |
| "step": torch.tensor(self.step, dtype=torch.int32), |
| "skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32), |
| "token": torch.tensor(self.token, dtype=torch.int64), |
| "elapsed": self.elapsed, |
| "global_avg_losses": global_avg_losses_bytes, |
| "global_max_losses": global_max_losses_bytes, |
| "log_steps": log_steps_bytes, |
| } |
|
|
| def load_state_dict(self, state_dict) -> None: |
| self.step = state_dict["step"].item() |
| self.skipped_step = state_dict.get("skipped_step", 0).item() |
| self.token = state_dict["token"].item() |
| self.elapsed = state_dict["elapsed"] |
| state_dict["global_avg_losses"].seek(0) |
| self.global_avg_losses = torch.load( |
| state_dict["global_avg_losses"], weights_only=False |
| ) |
| state_dict["global_max_losses"].seek(0) |
| self.global_max_losses = torch.load( |
| state_dict["global_max_losses"], weights_only=False |
| ) |
| state_dict["log_steps"].seek(0) |
| self.log_steps = torch.load(state_dict["log_steps"], weights_only=False) |
|
|