# ProFound/util/convnext_optim.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json

import torch
from torch import optim


def get_num_layer_for_convnext_single(var_name, depths):
    """
    Assign each layer a distinct layer id, counting blocks cumulatively
    across stages; a downsample layer shares the id of the first block
    of its stage.
    """
    if var_name.startswith("downsample_layers"):
        stage_id = int(var_name.split(".")[1])
        layer_id = sum(depths[:stage_id]) + 1
        return layer_id
    elif var_name.startswith("stages"):
        stage_id = int(var_name.split(".")[1])
        block_id = int(var_name.split(".")[2])
        layer_id = sum(depths[:stage_id]) + block_id + 1
        return layer_id
    else:
        return sum(depths) + 1
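

# A quick, hedged illustration (not in the original file) of the per-block
# mapping with depths=[3, 3, 27, 3]:
#
#   >>> get_num_layer_for_convnext_single("downsample_layers.0.0.weight", [3, 3, 27, 3])
#   1
#   >>> get_num_layer_for_convnext_single("stages.2.4.dwconv.weight", [3, 3, 27, 3])
#   11
#   >>> get_num_layer_for_convnext_single("head.weight", [3, 3, 27, 3])
#   37
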
def get_num_layer_for_convnext(var_name):
    """
    Divide the [3, 3, 27, 3] layers into 12 groups; each group is three
    consecutive blocks, including possible neighboring downsample layers;
    adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
    """
    num_max_layer = 12
    if var_name.startswith("downsample_layers"):
        stage_id = int(var_name.split(".")[1])
        if stage_id == 0:  # stem
            layer_id = 0
        elif stage_id == 1 or stage_id == 2:
            layer_id = stage_id + 1
        elif stage_id == 3:  # final downsample joins the last group
            layer_id = 12
        return layer_id
    elif var_name.startswith("stages"):
        stage_id = int(var_name.split(".")[1])
        block_id = int(var_name.split(".")[2])
        if stage_id == 0 or stage_id == 1:
            layer_id = stage_id + 1
        elif stage_id == 2:  # 27 blocks split into 9 groups of three
            layer_id = 3 + block_id // 3
        elif stage_id == 3:
            layer_id = 12
        return layer_id
    else:
        # anything outside the backbone (e.g. the head) goes past the last group
        return num_max_layer + 1
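

# For reference (a sketch, not in the original file), the grouped scheme above
# maps parameter names to group ids roughly as follows:
#
#   downsample_layers.0.*   -> 0   (stem)
#   stages.0.*              -> 1
#   downsample_layers.1.*,
#   stages.1.*              -> 2
#   downsample_layers.2.*   -> 3
#   stages.2.k.*            -> 3 + k // 3   (ids 3..11 for k = 0..26)
#   downsample_layers.3.*,
#   stages.3.*              -> 12
#   anything else           -> 13  (num_max_layer + 1)
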
class LayerDecayValueAssigner(object):
    def __init__(self, values, depths=[3, 3, 27, 3], layer_decay_type="single"):
        # values[i] is the lr scale applied to parameters with layer id i
        self.values = values
        self.depths = depths
        self.layer_decay_type = layer_decay_type

    def get_scale(self, layer_id):
        return self.values[layer_id]

    def get_layer_id(self, var_name):
        if self.layer_decay_type == "single":
            return get_num_layer_for_convnext_single(var_name, self.depths)
        else:
            return get_num_layer_for_convnext(var_name)
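

# A hedged construction sketch (the decay factor and schedule below are
# illustrative, not defined in this file): the grouped scheme yields ids
# 0..13, so a geometric layer-wise decay can be encoded as
#
#   layer_decay = 0.8
#   num_layers = 12
#   values = [layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)]
#   # any value other than "single" selects the grouped scheme
#   assigner = LayerDecayValueAssigner(values, layer_decay_type="group")
#
#   # early layers get the smallest scale, the head gets 1.0
#   assigner.get_scale(assigner.get_layer_id("stages.0.0.dwconv.weight"))  # 0.8 ** 12
#   assigner.get_scale(assigner.get_layer_id("head.weight"))               # 1.0
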
def get_parameter_groups(
    model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None
):
    parameter_group_names = {}
    parameter_group_vars = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if (
            len(param.shape) == 1
            or name.endswith(".bias")
            or name in skip_list
            or name.endswith(".gamma")
            or name.endswith(".beta")
        ):
            # 1-D tensors (norm scales, biases, gamma/beta) and explicitly
            # skipped parameters are excluded from weight decay
            group_name = "no_decay"
            this_weight_decay = 0.0
        else:
            group_name = "decay"
            this_weight_decay = weight_decay
        if get_num_layer is not None:
            layer_id = get_num_layer(name)
            group_name = "layer_%d_%s" % (layer_id, group_name)
        else:
            layer_id = None
        if group_name not in parameter_group_names:
            if get_layer_scale is not None:
                scale = get_layer_scale(layer_id)
            else:
                scale = 1.0
            parameter_group_names[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "lr_scale": scale,
            }
            parameter_group_vars[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "lr_scale": scale,
            }
        parameter_group_vars[group_name]["params"].append(param)
        parameter_group_names[group_name]["params"].append(name)
    print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
    return list(parameter_group_vars.values())
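

# A hedged usage sketch (not part of the original file): the returned list can
# be passed straight to a torch optimizer. "lr_scale" is a custom key that
# torch preserves on each param group; the training loop is expected to apply
# it whenever it sets the learning rate, e.g.:
#
#   assigner = LayerDecayValueAssigner(values)  # values as sketched above
#   groups = get_parameter_groups(
#       model,
#       weight_decay=0.05,
#       get_num_layer=assigner.get_layer_id,
#       get_layer_scale=assigner.get_scale,
#   )
#   optimizer = optim.AdamW(groups, lr=4e-3)
#   for group in optimizer.param_groups:
#       group["lr"] = group["lr"] * group["lr_scale"]  # layer-wise lr decay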