| import torch |
| import copy |
| from typing import List, Dict, Set, Any |
| import itertools |
|
|
def manipulate_params(cfg, model: torch.nn.Module) -> List[Dict[str, Any]]:
    """Build one optimizer parameter group per trainable parameter of ``model``.

    Each group is a dict with the keys ``"params"`` (a one-element list holding
    the parameter), ``"name"`` (a one-element list holding the name of the
    module that owns the parameter), ``"lr"`` and ``"weight_decay"``.

    Grouping rules, checked in this order for every parameter that has
    ``requires_grad=True``:

    1. If ``"vgg"`` occurs in the owning module's name or in the parameter's
       name, the parameter is a *finetuning* parameter: its learning rate is
       ``cfg.lr * 0.1`` and its weight decay is ``cfg.weight_decay``.
    2. Otherwise, if ``"train"`` occurs in either name, or the module name
       starts with one of the known trainable prefixes, the parameter is a
       *training* parameter with learning rate ``cfg.lr``.  Its weight decay
       is set to ``0.0`` when the parameter name contains
       ``"relative_position_bias_table"`` or ``"pos_embed"``, or when the
       owning module is a normalization layer or a ``torch.nn.Embedding``;
       otherwise it is ``cfg.weight_decay``.
    3. Anything else is rejected.

    Args:
        cfg: Object exposing the attributes ``lr`` and ``weight_decay``.
        model: Module whose parameters are grouped.

    Returns:
        All training groups followed by all finetuning groups.

    Raises:
        NotImplementedError: A trainable parameter matches neither rule 1
            nor rule 2.
    """
    defaults = {"lr": cfg.lr, "weight_decay": cfg.weight_decay}

    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )
    # Parameter-name fragments that switch weight decay off.
    no_decay_param_keys = ("relative_position_bias_table", "pos_embed")
    # Module-name prefixes treated as trainable even without "train" in the name.
    train_prefixes = (
        "patch_embeds",
        "f_blocks",
        "a_blocks",
        "fusion_modules",
        "smooth_convs",
        "train_proj_v1",
        "train_proj_a1",
        "text_proj",
    )

    params_training: List[Dict[str, Any]] = []
    params_finetuning: List[Dict[str, Any]] = []
    # Parameters already grouped; a parameter reachable through several
    # modules must end up in exactly one group.
    memo: Set[torch.nn.parameter.Parameter] = set()

    for module_name, module in model.named_modules():
        # recurse=False: only the parameters this module owns directly, so
        # `module` is the actual owner for the isinstance checks below.
        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad or value in memo:
                continue
            memo.add(value)
            hyperparams = copy.copy(defaults)

            if "vgg" in module_name or "vgg" in module_param_name:
                # Finetuning parameters: reduced learning rate.  Weight decay
                # is intentionally left at the default in this branch.
                hyperparams["lr"] *= 0.1
                params_finetuning.append(
                    {"params": [value], "name": [module_name], **hyperparams}
                )
            elif (
                "train" in module_name
                or "train" in module_param_name
                or module_name.startswith(train_prefixes)
            ):
                if (
                    any(key in module_param_name for key in no_decay_param_keys)
                    or isinstance(module, norm_module_types)
                    or isinstance(module, torch.nn.Embedding)
                ):
                    hyperparams["weight_decay"] = 0.0
                params_training.append(
                    {"params": [value], "name": [module_name], **hyperparams}
                )
            else:
                # Name the offender so the failure is actionable.
                raise NotImplementedError(
                    "undefined layer type: parameter "
                    f"'{module_param_name}' of module '{module_name}' matches "
                    "neither the 'vgg' nor the 'train' grouping rules."
                )

    final_list = params_training + params_finetuning
    # Internal sanity check: every trainable parameter got exactly one group.
    num_trainable = sum(1 for p in model.parameters() if p.requires_grad)
    assert num_trainable == len(final_list), 'checksum confirmed not pass.'
    return final_list