File size: 1,184 Bytes
5ad6849
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import random
import numpy as np
import torch
from torch import nn


def worker_init_fn(worker_id: int, rank: int = 0):
    """Derive numpy / torch / stdlib ``random`` seeds for one data-loading worker.

    Meant to be used as a ``DataLoader`` ``worker_init_fn`` (bind ``rank``
    with ``functools.partial`` when several processes are involved). One
    ``numpy.random.SeedSequence`` is built from the recovered base seed, the
    worker id and the rank. The legacy global numpy RNG is seeded from it
    directly, and two spawned child sequences seed torch and the stdlib
    ``random`` module, so each generator gets distinct seed material.

    Args:
        worker_id: Index of the current worker.
        rank: Global process rank; mixed into the seed so that the same
            worker id on different ranks yields different seeds.
    """
    # Assumes the PyTorch DataLoader convention that a worker's torch seed is
    # base_seed + worker_id; subtracting recovers the loader-wide base seed.
    seed_root = np.random.SeedSequence(
        [torch.initial_seed() - worker_id, worker_id, rank]
    )

    # Four 32-bit words -> 128 bits of entropy for the global numpy RNG.
    np.random.seed(seed_root.generate_state(4))

    children = seed_root.spawn(2)

    torch_words = children[0].generate_state(1, dtype=np.uint64)
    torch.manual_seed(torch_words[0])

    # Pack two 64-bit words into one 128-bit Python int for ``random.seed``.
    stdlib_words = children[1].generate_state(2, dtype=np.uint64)
    high, low = int(stdlib_words[0]), int(stdlib_words[1])
    random.seed((high << 64) | low)


def replace_with_fused_layernorm(module: nn.Module):
    """Replace every ``nn.LayerNorm`` in *module* with apex ``FusedLayerNorm``.

    Submodules are swapped in place on *module*'s tree and the (possibly new)
    root is returned, so call it as ``model = replace_with_fused_layernorm(model)``.
    Each fused layer is built with the original's ``normalized_shape``, ``eps``
    and ``elementwise_affine``. Its affine parameters are copied from the
    original, it is placed on the original's device/dtype, and it inherits the
    original's train/eval mode. As a result, an already trained or
    checkpoint-loaded model keeps its LayerNorm weights.

    If apex is not importable, *module* is returned unchanged.

    Args:
        module: Root of the module tree to rewrite. If it is itself an
            ``nn.LayerNorm``, its fused replacement is returned instead.

    Returns:
        The rewritten module: the same object as *module*, unless *module*
        was itself an ``nn.LayerNorm``.
    """
    try:
        from apex.normalization import FusedLayerNorm
    except ImportError:
        return module

    def _to_fused(ln: nn.LayerNorm) -> nn.Module:
        """Build a FusedLayerNorm equivalent to *ln*, including its parameters."""
        fused = FusedLayerNorm(ln.normalized_shape, ln.eps, ln.elementwise_affine)
        if ln.weight is not None:
            # Match the source parameters' device/dtype so the model doesn't end
            # up with a CPU/float32 layer inside a CUDA or half-precision model.
            fused.to(device=ln.weight.device, dtype=ln.weight.dtype)
            with torch.no_grad():
                fused.weight.copy_(ln.weight)
                if ln.bias is not None:
                    fused.bias.copy_(ln.bias)
                else:
                    # nn.LayerNorm(bias=False): a zero bias gives the same output.
                    fused.bias.zero_()
        fused.train(ln.training)
        return fused

    def _replace_children(parent: nn.Module) -> None:
        """Recursively swap LayerNorm children of *parent* in place."""
        for name, child in parent.named_children():
            if isinstance(child, nn.LayerNorm):
                # Re-registering an existing name replaces the value without
                # changing the children dict's size, so iteration stays valid.
                parent.register_module(name, _to_fused(child))
            else:
                _replace_children(child)

    if isinstance(module, nn.LayerNorm):
        return _to_fused(module)
    _replace_children(module)
    return module