import torch
import torch.optim as optim
import numpy as np
import itertools


def singleton(class_):
    """Cache and reuse a single instance of the decorated class."""
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance


@singleton
class get_optimizer(object):
    """Registry mapping optimizer names to their torch.optim classes."""

    def __init__(self):
        self.optimizer = {}
        self.register(optim.SGD, 'sgd')
        self.register(optim.Adam, 'adam')
        self.register(optim.AdamW, 'adamw')

    def register(self, optim, name):
        self.optimizer[name] = optim

    def __call__(self, net, cfg):
        if cfg is None:
            return None
        t = cfg.type
        # Unwrap (Distributed)DataParallel so custom attributes on the
        # underlying module remain reachable.
        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net
        pg = getattr(netm, 'parameter_group', None)

        if pg is not None:
            # Build one named param group per entry; each entry may be a
            # module, a single parameter, or a list of either.
            params = []
            for group_name, module_or_para in pg.items():
                if not isinstance(module_or_para, list):
                    module_or_para = [module_or_para]
                grouped_params = [
                    mi.parameters() if isinstance(mi, torch.nn.Module) else [mi]
                    for mi in module_or_para]
                grouped_params = itertools.chain(*grouped_params)
                pg_dict = {'params': grouped_params, 'name': group_name}
                params.append(pg_dict)
        else:
            params = net.parameters()
        # lr=0 is a placeholder (cfg.args must not repeat 'lr'); the effective
        # learning rate is expected to be set on the param groups afterwards,
        # e.g. by a scheduler.
        return self.optimizer[t](params, lr=0, **cfg.args)
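

# Minimal usage sketch, not part of the original module: the ToyNet model,
# the SimpleNamespace-based cfg (with .type / .args), and the 'adamw'
# settings below are illustrative assumptions about how this registry might
# be driven, not values mandated by this file.
if __name__ == '__main__':
    from types import SimpleNamespace

    class ToyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = torch.nn.Linear(8, 8)
            self.head = torch.nn.Linear(8, 2)
            # Optional attribute read by get_optimizer.__call__ to build named
            # parameter groups; omit it to optimize all parameters as one group.
            self.parameter_group = {
                'backbone': self.backbone,
                'head': self.head,
            }

        def forward(self, x):
            return self.head(self.backbone(x))

    # 'lr' is deliberately absent from args: __call__ passes lr=0 itself, and
    # the real rate is assigned to the groups afterwards.
    cfg = SimpleNamespace(type='adamw', args={'weight_decay': 0.01})
    optimizer = get_optimizer()(ToyNet(), cfg)
    for group in optimizer.param_groups:
        group['lr'] = 1e-4
    print([(g['name'], g['lr']) for g in optimizer.param_groups])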