Upload point_sam/utils/torch_utils.py
Browse files
point_sam/utils/torch_utils.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def worker_init_fn(worker_id: int, rank: int = 0):
|
| 8 |
+
global_rank = rank
|
| 9 |
+
process_seed = torch.initial_seed()
|
| 10 |
+
base_seed = process_seed - worker_id
|
| 11 |
+
ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
|
| 12 |
+
np.random.seed(ss.generate_state(4))
|
| 13 |
+
torch_ss, stdlib_ss = ss.spawn(2)
|
| 14 |
+
torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0])
|
| 15 |
+
stdlib_seed = (
|
| 16 |
+
stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]
|
| 17 |
+
).sum()
|
| 18 |
+
random.seed(stdlib_seed)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def replace_with_fused_layernorm(module: nn.Module):
|
| 22 |
+
"""Replace LayerNorm with apex FusedLayerNorm if available."""
|
| 23 |
+
try:
|
| 24 |
+
from apex.normalization import FusedLayerNorm
|
| 25 |
+
except ImportError:
|
| 26 |
+
return module
|
| 27 |
+
|
| 28 |
+
for name, child in module.named_children():
|
| 29 |
+
if isinstance(child, nn.LayerNorm):
|
| 30 |
+
fused_layernorm = FusedLayerNorm(
|
| 31 |
+
child.normalized_shape, child.eps, child.elementwise_affine
|
| 32 |
+
)
|
| 33 |
+
module.register_module(name, fused_layernorm)
|
| 34 |
+
else:
|
| 35 |
+
replace_with_fused_layernorm(child)
|
| 36 |
+
return module
|