| import math |
| from typing import Dict, List, Union |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn.modules.batchnorm import _BatchNorm |
|
|
# Names exported by a wildcard import (`from <module> import *`) of this module.
__all__ = ["init_modules", "load_state_dict"]
|
|
|
|
def init_modules(
    module: Union[nn.Module, List[nn.Module]], init_type="he_fout"
) -> None:
    """Initialize the parameters of ``module`` (or of each module in a list) in place.

    Args:
        module: A single ``nn.Module`` or a list/tuple of modules. Every
            sub-module reachable through ``.modules()`` is visited.
        init_type: Initialization scheme, optionally followed by ``@<float>``.
            ``"he_fout"`` draws Conv2d weights from a normal distribution with
            std ``sqrt(2 / (kh * kw * out_channels))``. Any other value
            (e.g. ``"kaiming_uniform"`` or ``"kaiming_uniform@5"``) uses
            ``nn.init.kaiming_uniform_`` with ``a = sqrt(<float>)``; the float
            defaults to 5 when no ``@`` suffix is given.

    Per module type:
        * ``nn.Conv2d``: weight per ``init_type``; bias zeroed.
        * ``_BatchNorm``: weight set to 1 and bias to 0 (when present).
        * ``nn.Linear``: weight from a truncated normal with std 0.02; bias zeroed.
        * anything else: a ``Parameter`` weight gets kaiming-uniform, a
          ``Parameter`` bias is zeroed.
    """
    # "name@value" carries an optional numeric argument after the "@".
    parts = init_type.split("@")
    init_param = float(parts[1]) if len(parts) > 1 else None
    # Explicit None check so that an explicit "@0" is honored rather than
    # being replaced by the default of 5.
    a_squared = 5 if init_param is None else init_param

    if isinstance(module, (list, tuple)):
        for sub_module in module:
            # Forward init_type so the list items use the requested scheme
            # instead of silently falling back to the default.
            init_modules(sub_module, init_type)
        return

    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            if init_type == "he_fout":
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            else:
                # "kaiming_uniform[@x]" and any unrecognized scheme share
                # the same kaiming-uniform initialization.
                nn.init.kaiming_uniform_(m.weight, a=math.sqrt(a_squared))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, _BatchNorm):
            # Norm layers built with affine=False have weight/bias set to None.
            if m.weight is not None:
                m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                m.bias.data.zero_()
        else:
            weight = getattr(m, "weight", None)
            bias = getattr(m, "bias", None)
            if isinstance(weight, torch.nn.Parameter):
                # NOTE(review): kaiming_uniform_ needs a tensor with at least
                # two dimensions; a module with a 1-D weight (e.g. LayerNorm)
                # would presumably raise here -- confirm intended handling.
                nn.init.kaiming_uniform_(weight, a=math.sqrt(a_squared))
            if isinstance(bias, torch.nn.Parameter):
                bias.data.zero_()
|
|
|
|
def load_state_dict(
    model: nn.Module, state_dict: Dict[str, torch.Tensor], strict=True
) -> None:
    """Copy matching entries of ``state_dict`` into ``model``.

    Each tensor in ``state_dict`` is copied into the model's entry of the same
    name when the shapes agree. Keys of the model that are absent from
    ``state_dict`` keep their current values.

    Args:
        model: Module whose parameters/buffers are updated.
        state_dict: Mapping from parameter/buffer name to source tensor.
        strict: When true, a shape mismatch or a key unknown to ``model``
            raises; when false, the offending entry is skipped with a
            printed message.

    Raises:
        ValueError: ``strict`` is true and a tensor's shape differs from the
            model's tensor of the same name.
        KeyError: ``strict`` is true and ``state_dict`` contains a key that
            the model does not have.
    """
    current_state_dict = model.state_dict()
    for key, source in state_dict.items():
        if key not in current_state_dict:
            # Previously this surfaced as a bare KeyError even with
            # strict=False; non-strict loading now skips unknown keys.
            if strict:
                raise KeyError("%s not found in the model state dict" % key)
            print("Skip loading %s: not found in the model state dict" % key)
            continue

        target = current_state_dict[key]
        if target.shape != source.shape:
            mismatch = "%s due to shape mismatch (src=%s, target=%s)" % (
                key,
                list(source.shape),
                list(target.shape),
            )
            if strict:
                raise ValueError(
                    "%s shape mismatch (src=%s, target=%s)"
                    % (key, list(source.shape), list(target.shape))
                )
            print("Skip loading " + mismatch)
        else:
            target.copy_(source)
    model.load_state_dict(current_state_dict)
|
|