| |
| |
| |
| |
| |
| |
|
|
| import os |
| import re |
| import socket |
| import torch |
| import torch.distributed |
| from . import training_stats |
|
|
# Device passed to training_stats.init_multiprocessing() as `sync_device`.
# init() sets it to torch.device('cuda') when the world size exceeds 1,
# and to None otherwise.
_sync_device = None
|
|
| |
|
|
def init():
    """Initialize torch.distributed (if needed) and the training_stats module.

    When no process group exists yet, any of the environment variables
    MASTER_ADDR, MASTER_PORT, RANK, LOCAL_RANK and WORLD_SIZE that are not
    already set receive single-process defaults, the default process group
    is created via the ``env://`` init method, and the current CUDA device is
    set to LOCAL_RANK. The backend is 'gloo' when ``os.name == 'nt'`` and
    'nccl' otherwise.

    Afterwards the module-level ``_sync_device`` is set to the CUDA device
    when the world size is greater than 1 (None otherwise) and handed to
    ``training_stats.init_multiprocessing`` together with this process' rank.
    """
    global _sync_device

    if not torch.distributed.is_initialized():
        # Fill in defaults for env-based initialization without overriding
        # anything the launcher has already provided.
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        if 'MASTER_PORT' not in os.environ:
            # Bind to port 0 so the OS picks a free port. The context manager
            # guarantees the probe socket is closed even if bind() or
            # setsockopt() raises.
            # NOTE: the port is released again before init_process_group()
            # uses it, so another process could grab it in between.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                probe.bind(('', 0))
                probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                os.environ['MASTER_PORT'] = str(probe.getsockname()[1])
        os.environ.setdefault('RANK', '0')
        os.environ.setdefault('LOCAL_RANK', '0')
        os.environ.setdefault('WORLD_SIZE', '1')

        backend = 'gloo' if os.name == 'nt' else 'nccl'
        torch.distributed.init_process_group(backend=backend, init_method='env://')
        torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))

    # Statistics only need a synchronization device when there are peers.
    _sync_device = torch.device('cuda') if get_world_size() > 1 else None
    training_stats.init_multiprocessing(rank=get_rank(), sync_device=_sync_device)
|
|
| |
|
|
def get_rank():
    """Return this process' rank, or 0 when torch.distributed is not initialized."""
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()
|
|
| |
|
|
def get_world_size():
    """Return the number of processes, or 1 when torch.distributed is not initialized."""
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()
|
|
| |
|
|
def should_stop():
    """Tell the caller whether to stop; this implementation always returns False.

    NOTE(review): looks like a placeholder hook for an external job
    controller -- confirm against the callers.
    """
    return False
|
|
| |
|
|
def should_suspend():
    """Tell the caller whether to suspend; this implementation always returns False.

    NOTE(review): looks like a placeholder hook for an external job
    controller -- confirm against the callers.
    """
    return False
|
|
| |
|
|
def request_suspend():
    """Do nothing and return None.

    NOTE(review): presumably a placeholder for asking a job controller to
    suspend the run -- confirm against the callers.
    """
|
|
| |
|
|
def update_progress(cur, total):
    """Accept a progress report and ignore it; returns None.

    Both `cur` and `total` are unused by this implementation.
    """
|
|
| |
|
|
def print0(*args, **kwargs):
    """Forward all arguments to print(), but only on the process whose rank is 0."""
    if get_rank() != 0:
        return
    print(*args, **kwargs)
|
|
| |
|
|
class CheckpointIO:
    """Save and restore the state of several named objects in a single file.

    Every keyword argument passed to the constructor registers one object
    under that name. How an object's state is captured (`save`) and restored
    (`load`) is decided by the first matching rule:

    * ``None``            -- stored as ``None``; skipped when loading.
    * ``dict``            -- stored as-is; cleared and refilled in place on load.
    * ``state_dict()`` / ``load_state_dict()`` -- used when present.
    * ``__getstate__()`` / ``__setstate__()``  -- used when present.
    * ``__dict__``        -- stored directly; cleared and refilled in place on load.

    Anything else raises ``ValueError``.
    """

    def __init__(self, **kwargs):
        # Mapping of checkpoint entry name -> live object whose state is tracked.
        self._state_objs = kwargs

    @staticmethod
    def _has_custom_getstate(obj):
        """Return True if `obj` provides its own ``__getstate__``.

        Python 3.11 added a default ``object.__getstate__``, so a plain
        ``hasattr`` check succeeds for every object there. Only a method that
        differs from that default counts as user-provided.
        """
        default = getattr(object, '__getstate__', None)
        if default is None:
            # Python < 3.11: any __getstate__ found was supplied by the object.
            return hasattr(obj, '__getstate__')
        return getattr(type(obj), '__getstate__', default) is not default

    def save(self, pt_path, verbose=True):
        """Collect the state of all registered objects and write it to `pt_path`.

        The state is gathered on every process, but only rank 0 calls
        ``torch.save``. With `verbose`, progress is printed via `print0`.

        Raises:
            ValueError: if a registered object matches none of the supported
                state protocols.
        """
        if verbose:
            print0(f'Saving {pt_path} ... ', end='', flush=True)
        data = dict()
        for name, obj in self._state_objs.items():
            if obj is None:
                data[name] = None
            elif isinstance(obj, dict):
                data[name] = obj
            elif hasattr(obj, 'state_dict'):
                data[name] = obj.state_dict()
            elif self._has_custom_getstate(obj):
                # Ignore the default object.__getstate__ of Python 3.11+, which
                # would otherwise shadow the two branches below.
                data[name] = obj.__getstate__()
            elif hasattr(obj, '__dict__'):
                data[name] = obj.__dict__
            else:
                raise ValueError(f'Invalid state object of type {type(obj).__name__}')
        if get_rank() == 0:
            torch.save(data, pt_path)
        if verbose:
            print0('done')

    def load(self, pt_path, verbose=True):
        """Read `pt_path` and restore every registered object from it.

        Tensors are mapped to the CPU while loading. Objects registered as
        ``None`` are skipped; all others are updated in place. With `verbose`,
        progress is printed via `print0`.

        Raises:
            KeyError: if the file has no entry for a registered, non-None object.
            ValueError: if a registered object matches none of the supported
                state protocols.
        """
        if verbose:
            print0(f'Loading {pt_path} ... ', end='', flush=True)
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # that come from a trusted source.
        data = torch.load(pt_path, map_location=torch.device('cpu'))
        for name, obj in self._state_objs.items():
            if obj is None:
                pass
            elif isinstance(obj, dict):
                obj.clear()
                obj.update(data[name])
            elif hasattr(obj, 'load_state_dict'):
                obj.load_state_dict(data[name])
            elif hasattr(obj, '__setstate__'):
                obj.__setstate__(data[name])
            elif hasattr(obj, '__dict__'):
                obj.__dict__.clear()
                obj.__dict__.update(data[name])
            else:
                raise ValueError(f'Invalid state object of type {type(obj).__name__}')
        if verbose:
            print0('done')

    def load_latest(self, run_dir, pattern=r'training-state-(\d+).pt', verbose=True):
        """Load the checkpoint in `run_dir` with the highest number in its name.

        Regular files whose whole name matches `pattern` are candidates; they
        are ranked by the first capture group converted to ``float``.

        Returns:
            The path of the checkpoint that was loaded, or None when no file
            in `run_dir` matches `pattern`.
        """
        # Match each name once and remember its numeric key.
        numbered = {}
        with os.scandir(run_dir) as entries:
            for entry in entries:
                if not entry.is_file():
                    continue
                match = re.fullmatch(pattern, entry.name)
                if match:
                    numbered[entry.name] = float(match.group(1))
        if not numbered:
            return None
        pt_path = os.path.join(run_dir, max(numbered, key=numbered.get))
        self.load(pt_path, verbose=verbose)
        return pt_path
|
|
| |
|
|