| """
|
| torch_utils.py
|
|
|
| General utilities for randomness, mixed precision training, and miscellaneous checks in PyTorch.
|
|
|
The `set_global_seed` functionality is taken directly from PyTorch-Lightning:
|
| > Ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/seed.py
|
|
|
This is pretty important to get right if we're ever randomly generating our masks (or prefix dropout) inside our
|
| Dataset __getitem__() with multiple workers... if not handled properly, we will get repeated augmentations anytime
|
| we inject randomness from non-PyTorch sources (e.g., numpy, random)!
|
| > Ref: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/
|
|
|
| Terminology
|
-> World Size :: Total number of processes distributed over (# nodes x # devices) -- assumed homogeneous!
|
| -> Rank :: Integer index of current process in the total world size
|
-> Local Rank :: Local index on given node in [0, Devices per Node)
|
| """
|
|
|
| import os
|
| import random
|
| from typing import Callable, Optional
|
|
|
| import numpy as np
|
| import torch
|
|
|
|
|
|
|
|
|
def set_global_seed(seed: int, get_worker_init_fn: bool = False) -> Optional[Callable[[int], None]]:
    """Seed `random`, `numpy`, and `torch`, and optionally return a DataLoader `worker_init_fn`.

    The seed is also exported through `os.environ` as ``EXPERIMENT_GLOBAL_SEED``.

    :param seed: Seed value; must lie in the inclusive ``np.uint32`` range ``[0, 2**32 - 1]``, which is the range
        accepted by ``np.random.seed``.
    :param get_worker_init_fn: If True, return `worker_init_function` so it can be passed as a DataLoader
        ``worker_init_fn``.

    :return: `worker_init_function` if ``get_worker_init_fn`` is True, otherwise None.
    :raises ValueError: If ``seed`` lies outside the ``np.uint32`` range.
    """
    bounds = np.iinfo(np.uint32)

    # Inclusive on both ends: 0 and 2**32 - 1 are valid `np.random.seed` values. An explicit check (instead of
    # `assert`) keeps the validation active under `python -O`.
    if not (bounds.min <= seed <= bounds.max):
        raise ValueError(f"Seed outside the np.uint32 bounds! Got {seed}")

    # Record the seed in the environment so it is visible to anything that reads it later.
    os.environ["EXPERIMENT_GLOBAL_SEED"] = str(seed)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    return worker_init_function if get_worker_init_fn else None
|
|
|
|
|
def worker_init_function(worker_id: int) -> None:
    """
    Re-seed `numpy`, `torch`, and `random` inside a DataLoader worker so each (rank, worker) pair draws from its own
    reproducible random stream.

    Borrowed directly from PyTorch-Lightning; inspired by this issue comment in the PyTorch repo:
        > Ref: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562

    Intuition: You can think of the seed sequence spawn function as a "janky" torch.Generator() or jax.PRNGKey that
    you can run iterative splitting on to get new (predictable) randomness.

    The process rank is taken from the ``RANK`` environment variable (global rank). If that is unset, it falls back
    to ``LOCAL_RANK``, and then to 0 when neither is set (e.g., a plain single-process run).

    :param worker_id: Identifier for the given worker [0, num_workers) for the Dataloader in question.
    """
    # Key on the *global* rank. Using LOCAL_RANK alone would give workers on different nodes that share a local rank
    # (and worker_id) identical seeds, i.e., repeated augmentations across nodes. On a single node, RANK == LOCAL_RANK,
    # so seeds there are unchanged.
    global_rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", 0)))
    process_seed = torch.initial_seed()

    # Per the torch.utils.data docs, a worker's torch seed is base_seed + worker_id; back out the shared base seed.
    base_seed = process_seed - worker_id

    # Mix base seed, worker id, and rank so every (rank, worker) pair gets distinct entropy.
    seed_seq = np.random.SeedSequence([base_seed, worker_id, global_rank])

    # Seed the legacy numpy global RNG with 4 x uint32 words of state.
    np.random.seed(seed_seq.generate_state(4))

    # Spawn independent child sequences so torch and `random` do not share a stream.
    torch_seed_seq, random_seed_seq = seed_seq.spawn(2)

    # torch.manual_seed takes a single 64-bit integer.
    torch.manual_seed(torch_seed_seq.generate_state(1, dtype=np.uint64)[0])

    # `random.seed` accepts arbitrarily large ints: combine two uint64 words into one 128-bit Python int. `.astype(list)`
    # moves to object dtype (Python ints) so the `1 << 64` multiply cannot overflow uint64.
    random_seed = (random_seed_seq.generate_state(2, dtype=np.uint64).astype(list) * [1 << 64, 1]).sum()
    random.seed(random_seed)
|
|
|
|
|
|
|
|
|
|
|
def check_bloat16_supported() -> bool:
    """Return True only when bf16 mixed precision looks usable with CUDA + NCCL on this machine.

    All of the following must hold:
      * torch is a CUDA build (``torch.version.cuda`` is set),
      * ``torch.cuda.is_bf16_supported()`` reports True,
      * the CUDA version is at least 11.0,
      * the NCCL distributed backend is available,
      * the NCCL version is at least 2.10.

    Any exception raised along the way (e.g., `packaging` not installed, no visible GPU) yields False.
    NOTE: the "bloat16" spelling is kept as-is because callers may import this name.
    """
    try:
        import packaging.version
        import torch.cuda.nccl as nccl
        import torch.distributed as dist

        cuda_version = torch.version.cuda
        if cuda_version is None:
            return False
        if not torch.cuda.is_bf16_supported():
            return False
        if packaging.version.parse(cuda_version).release < (11, 0):
            return False
        if not dist.is_nccl_available():
            return False
        return nccl.version() >= (2, 10)

    except Exception:
        return False
|
|
|