import random

import numpy as np
import torch
from torch import nn


def worker_init_fn(worker_id: int, rank: int = 0) -> None:
    """Seed NumPy, torch and stdlib ``random`` for a DataLoader worker.

    PyTorch only re-seeds its own RNG per worker (with ``base_seed +
    worker_id``); this hook derives independent, reproducible streams for
    NumPy and ``random`` as well, keyed on (base_seed, worker_id, rank)
    so different workers and different distributed ranks get distinct,
    deterministic randomness.

    Args:
        worker_id: Index of the DataLoader worker process.
        rank: Global (distributed) rank of this process. Defaults to 0.
    """
    global_rank = rank
    # Inside a worker, torch.initial_seed() == base_seed + worker_id, so
    # subtracting worker_id recovers the per-process base seed.
    process_seed = torch.initial_seed()
    base_seed = process_seed - worker_id
    ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
    # np.random.seed accepts an array of uint32 words (MT19937 init_by_array).
    np.random.seed(ss.generate_state(4))
    torch_ss, stdlib_ss = ss.spawn(2)
    torch.manual_seed(int(torch_ss.generate_state(1, dtype=np.uint64)[0]))
    # Fold two 64-bit words into one 128-bit int for random.seed(); the
    # object dtype prevents uint64 overflow in the multiply/sum.
    stdlib_seed = (
        stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]
    ).sum()
    random.seed(stdlib_seed)


def replace_with_fused_layernorm(module: nn.Module) -> nn.Module:
    """Recursively replace every ``nn.LayerNorm`` with apex ``FusedLayerNorm``.

    No-op (returns *module* unchanged) when apex is not installed. The
    rewrite happens in place; the module is also returned for convenience.

    Args:
        module: Root module whose LayerNorm descendants are replaced.

    Returns:
        The same module, with LayerNorm children swapped for FusedLayerNorm.
    """
    try:
        from apex.normalization import FusedLayerNorm
    except ImportError:
        return module

    for name, child in module.named_children():
        if isinstance(child, nn.LayerNorm):
            fused = FusedLayerNorm(
                child.normalized_shape, child.eps, child.elementwise_affine
            )
            if child.elementwise_affine:
                # Bug fix: carry over the (possibly pretrained) affine
                # parameters — a fresh FusedLayerNorm would silently
                # re-initialize them to ones/zeros.
                with torch.no_grad():
                    fused.weight.copy_(child.weight)
                    fused.bias.copy_(child.bias)
                # Match the original's device/dtype so the swap stays safe
                # even after the model has been moved to GPU / cast.
                fused = fused.to(
                    device=child.weight.device, dtype=child.weight.dtype
                )
            module.register_module(name, fused)
        else:
            replace_with_fused_layernorm(child)
    return module