import random

import numpy as np
import torch
from torch import nn


def worker_init_fn(worker_id: int, rank: int = 0):
    """Seed numpy, torch, and stdlib random inside a DataLoader worker.

    Derives a distinct, reproducible seed from (base_seed, worker_id, rank)
    so augmentations differ across workers and distributed ranks.
    """
    global_rank = rank
    # Inside a worker, torch.initial_seed() == base_seed + worker_id, so
    # subtracting worker_id recovers the base seed shared by all workers.
    process_seed = torch.initial_seed()
    base_seed = process_seed - worker_id
    ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
    # The legacy numpy RNG accepts a sequence of uint32 words as its seed.
    np.random.seed(ss.generate_state(4))
    torch_ss, stdlib_ss = ss.spawn(2)
    torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0])
    # stdlib random takes arbitrarily large ints: fold two 64-bit words
    # into a single 128-bit seed.
    stdlib_seed = (
        stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]
    ).sum()
    random.seed(stdlib_seed)
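

# A minimal usage sketch: bind this process's rank with functools.partial so
# every worker is seeded from (base_seed, worker_id, rank). The `dataset` and
# `rank` arguments are illustrative placeholders, not part of this module.
def make_seeded_loader(dataset, rank: int = 0, num_workers: int = 4):
    from functools import partial

    from torch.utils.data import DataLoader

    return DataLoader(
        dataset,
        num_workers=num_workers,
        worker_init_fn=partial(worker_init_fn, rank=rank),
    )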


def replace_with_fused_layernorm(module: nn.Module) -> nn.Module:
    """Recursively replace nn.LayerNorm with apex FusedLayerNorm if available."""
    try:
        from apex.normalization import FusedLayerNorm
    except ImportError:
        # apex is an optional dependency; leave the module unchanged.
        return module

    for name, child in module.named_children():
        if isinstance(child, nn.LayerNorm):
            fused_layernorm = FusedLayerNorm(
                child.normalized_shape, child.eps, child.elementwise_affine
            )
            if child.elementwise_affine:
                # Match the original layer's device/dtype, then copy the
                # learned affine parameters so the swap does not reset them.
                fused_layernorm = fused_layernorm.to(
                    device=child.weight.device, dtype=child.weight.dtype
                )
                fused_layernorm.load_state_dict(child.state_dict())
            module.register_module(name, fused_layernorm)
        else:
            replace_with_fused_layernorm(child)
    return module
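

# A minimal, self-contained usage sketch on a toy stack of layers; the model
# is illustrative only. The replacement mutates the module in place, so one
# call on the root module suffices (a no-op when apex is not installed).
def example_fused_swap() -> nn.Module:
    model = nn.Sequential(
        nn.Linear(16, 16),
        nn.LayerNorm(16),
        nn.Linear(16, 16),
        nn.LayerNorm(16),
    )
    return replace_with_fused_layernorm(model)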