import random
import numpy as np
import torch
from torch import nn
def worker_init_fn(worker_id: int, rank: int = 0):
global_rank = rank
process_seed = torch.initial_seed()
base_seed = process_seed - worker_id
ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
np.random.seed(ss.generate_state(4))
torch_ss, stdlib_ss = ss.spawn(2)
torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0])
stdlib_seed = (
stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]
).sum()
random.seed(stdlib_seed)
def replace_with_fused_layernorm(module: nn.Module):
    """Recursively replace every ``nn.LayerNorm`` in *module* with apex's
    ``FusedLayerNorm``, if apex is available.

    Args:
        module: Root module to rewrite in place.

    Returns:
        The same ``module`` instance (modified in place when apex is
        installed; returned unchanged otherwise).
    """
    try:
        from apex.normalization import FusedLayerNorm
    except ImportError:
        # apex is optional: without it, leave the module untouched.
        return module
    for name, child in module.named_children():
        if isinstance(child, nn.LayerNorm):
            fused_layernorm = FusedLayerNorm(
                child.normalized_shape, child.eps, child.elementwise_affine
            )
            # BUG FIX: the original dropped the trained affine parameters,
            # replacing them with FusedLayerNorm's fresh initialization.
            # Copy weight/bias (and match device/dtype) from the old layer.
            if child.elementwise_affine:
                with torch.no_grad():
                    fused_layernorm = fused_layernorm.to(
                        device=child.weight.device, dtype=child.weight.dtype
                    )
                    fused_layernorm.weight.copy_(child.weight)
                    fused_layernorm.bias.copy_(child.bias)
            module.register_module(name, fused_layernorm)
        else:
            # Recurse into composite children to reach nested LayerNorms.
            replace_with_fused_layernorm(child)
    return module
|