point-sam-inference / point_sam /utils /torch_utils.py
bdck's picture
Upload point_sam/utils/torch_utils.py
5ad6849 verified
import random
import numpy as np
import torch
from torch import nn
def worker_init_fn(worker_id: int, rank: int = 0):
    """Seed NumPy, PyTorch and the stdlib ``random`` module for one worker.

    All three seeds are derived from a single ``np.random.SeedSequence`` built
    from ``(torch.initial_seed() - worker_id, worker_id, rank)``, so each
    (worker, rank) pair gets its own reproducible random streams.

    Args:
        worker_id: Index of the worker being initialised.
        rank: Rank of the owning process; mixed into the seed entropy.
    """
    # NOTE(review): presumably this runs inside a DataLoader worker, where the
    # torch seed already includes a per-worker offset that is removed here to
    # recover a seed shared by all workers -- confirm against the caller.
    shared_seed = torch.initial_seed() - worker_id
    seed_seq = np.random.SeedSequence([shared_seed, worker_id, rank])

    # NumPy's global RNG is seeded with four words of generated state.
    np.random.seed(seed_seq.generate_state(4))

    # Independent child sequences for torch and for the stdlib RNG.
    torch_seq, python_seq = seed_seq.spawn(2)

    torch_seed = torch_seq.generate_state(1, dtype=np.uint64)[0]
    torch.manual_seed(torch_seed)

    # Combine two 64-bit words into one 128-bit Python int (first word high).
    words = python_seq.generate_state(2, dtype=np.uint64)
    high_word, low_word = (int(word) for word in words)
    random.seed((high_word << 64) + low_word)
def replace_with_fused_layernorm(module: nn.Module):
    """Recursively swap ``nn.LayerNorm`` children for apex ``FusedLayerNorm``.

    If apex cannot be imported, ``module`` is returned untouched. Otherwise
    every ``nn.LayerNorm`` found among the descendants of ``module`` is
    replaced in place by a ``FusedLayerNorm`` with the same
    ``normalized_shape``, ``eps`` and ``elementwise_affine``. The replacement
    also inherits the original layer's affine parameter values, device, dtype
    and ``requires_grad`` flags, so the swap is safe on a model that has
    already been initialised or loaded from a checkpoint.

    Only children are inspected: if ``module`` itself is a ``LayerNorm`` it is
    returned unchanged.

    Args:
        module: Root module whose descendants are rewritten in place.

    Returns:
        The same ``module`` object, for call chaining.
    """
    try:
        from apex.normalization import FusedLayerNorm
    except ImportError:
        return module

    # Snapshot the children because entries are replaced while walking them.
    for name, child in list(module.named_children()):
        if not isinstance(child, nn.LayerNorm):
            replace_with_fused_layernorm(child)
            continue

        fused_layernorm = FusedLayerNorm(
            child.normalized_shape, child.eps, child.elementwise_affine
        )

        # A freshly built layer has default-initialised parameters on the
        # default device; carry the original layer's state over so the swap
        # does not silently reset learned scale/shift values.
        if child.weight is not None:
            fused_layernorm.to(
                device=child.weight.device, dtype=child.weight.dtype
            )
            with torch.no_grad():
                fused_layernorm.weight.copy_(child.weight)
                # The source layer may have been built without a bias; in
                # that case the replacement keeps its own initial bias.
                if child.bias is not None:
                    fused_layernorm.bias.copy_(child.bias)
            fused_layernorm.weight.requires_grad_(child.weight.requires_grad)
            if child.bias is not None:
                fused_layernorm.bias.requires_grad_(child.bias.requires_grad)

        module.register_module(name, fused_layernorm)
    return module