bdck commited on
Commit
5ad6849
·
verified ·
1 Parent(s): a3f7a7c

Upload point_sam/utils/torch_utils.py

Browse files
Files changed (1) hide show
  1. point_sam/utils/torch_utils.py +36 -0
point_sam/utils/torch_utils.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ def worker_init_fn(worker_id: int, rank: int = 0):
8
+ global_rank = rank
9
+ process_seed = torch.initial_seed()
10
+ base_seed = process_seed - worker_id
11
+ ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
12
+ np.random.seed(ss.generate_state(4))
13
+ torch_ss, stdlib_ss = ss.spawn(2)
14
+ torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0])
15
+ stdlib_seed = (
16
+ stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]
17
+ ).sum()
18
+ random.seed(stdlib_seed)
19
+
20
+
21
+ def replace_with_fused_layernorm(module: nn.Module):
22
+ """Replace LayerNorm with apex FusedLayerNorm if available."""
23
+ try:
24
+ from apex.normalization import FusedLayerNorm
25
+ except ImportError:
26
+ return module
27
+
28
+ for name, child in module.named_children():
29
+ if isinstance(child, nn.LayerNorm):
30
+ fused_layernorm = FusedLayerNorm(
31
+ child.normalized_shape, child.eps, child.elementwise_affine
32
+ )
33
+ module.register_module(name, fused_layernorm)
34
+ else:
35
+ replace_with_fused_layernorm(child)
36
+ return module