# ProFound/models/build_classification.py
import os

import torch

from models.classifier import Classifier
from models.convnextv2 import convnextv2_tiny, remap_checkpoint_keys, load_state_dict
from util.convnext_optim import get_parameter_groups, LayerDecayValueAssigner
from util.lars import LARS


def build_model(args, device):
    """Build the classification model and its optimizer according to `args`."""
    if args.model == "profound_conv":
        # ConvNeXt V2-Tiny backbone: 3-channel input, stochastic depth 0.1.
        convnext = convnextv2_tiny(in_chans=3, drop_path_rate=0.1)
        if args.pretrain is None:
            raise ValueError("No pretrained weights specified")
        if not os.path.exists(args.pretrain):
            raise FileNotFoundError(f"{args.pretrain} does not exist")
        # Load the checkpoint on CPU; `weights_only=False` belongs on
        # torch.load, since the checkpoint may contain non-tensor objects.
        ckpt = torch.load(args.pretrain, map_location="cpu", weights_only=False)
        # Remap the pre-training keys to the fine-tuning layout before loading.
        ckpt = remap_checkpoint_keys(ckpt)
        load_state_dict(convnext, ckpt)
model = Classifier(convnext, args.num_classes)
model = model.to(device)
if args.train == "freeze":
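            # Linear probing: freeze the whole encoder and train only the
            # classification head (LARS, no weight decay).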
            for param in model.encoder.parameters():
                param.requires_grad = False
optimizer = LARS(model.head.parameters(), weight_decay=0, lr=args.lr)
else:
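            # Full fine-tuning: layer-wise lr decay on the backbone (deeper
            # layers keep a larger share of the base lr, with
            # scale = layer_decay ** (num_layers + 1 - layer_id)).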
            num_layers = sum(convnext.depths)
            assigner = LayerDecayValueAssigner(
                [
                    args.layer_decay ** (num_layers + 1 - i)
                    for i in range(num_layers + 2)
                ],
                depths=convnext.depths,
                layer_decay_type=args.layer_decay_type,
            )
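            # Parameter names the backbone exempts from weight decay, if any.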
            skip = set()
if hasattr(model.encoder, "no_weight_decay"):
skip = model.encoder.no_weight_decay()
backbone_param_groups = get_parameter_groups(
model.encoder,
args.weight_decay,
skip,
assigner.get_layer_id,
assigner.get_scale,
)
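            # The classification head gets the base lr and no weight decay.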
decoder_param_groups = [
{"params": model.head.parameters(), "weight_decay": 0.0, "lr": args.lr}
]
optimizer = torch.optim.AdamW(
backbone_param_groups + decoder_param_groups, lr=args.lr
)
else:
raise NotImplementedError(f"unknown model: {args.model}")
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model = {model}")
    print(f"number of params (M): {n_parameters / 1.0e6:.2f}")
    return model, optimizer
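
# Example usage (a minimal sketch; the checkpoint path and hyper-parameter
# values below are illustrative assumptions, not fixed by this module):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(
#       model="profound_conv",
#       pretrain="weights/convnextv2_tiny.pt",  # hypothetical path
#       num_classes=10,
#       train="freeze",  # any other value selects full fine-tuning
#       lr=1e-3,
#       layer_decay=0.9,
#       layer_decay_type="single",  # assumed value; unused in freeze mode
#       weight_decay=0.05,
#   )
#   model, optimizer = build_model(args, torch.device("cpu"))
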
def vit_backbone_parameters(
model: torch.nn.Module, weight_decay=1e-5, no_weight_decay_list=(), lr=1e-3
):
    """Split `model`'s trainable parameters into two optimizer groups:
    biases, 1-D tensors (e.g. norm weights), and names listed in
    `no_weight_decay_list` get zero weight decay; everything else decays."""
    no_weight_decay_list = set(no_weight_decay_list)
    decay = []
    no_decay = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
no_decay.append(param)
else:
decay.append(param)
return [
{"params": no_decay, "weight_decay": 0.0, "lr": lr},
{"params": decay, "weight_decay": weight_decay, "lr": lr},
]
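
# Example usage (a minimal sketch; `vit` stands for any ViT-style backbone,
# and the hyper-parameter values are illustrative):
#
#   param_groups = vit_backbone_parameters(
#       vit,
#       weight_decay=0.05,
#       no_weight_decay_list=vit.no_weight_decay(),  # if the model defines it
#       lr=1e-3,
#   )
#   optimizer = torch.optim.AdamW(param_groups)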