| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| from __future__ import unicode_literals |
| import copy |
| import paddle |
|
|
# Public API of this module: ``from <this module> import *`` exposes only the
# optimizer factory; build_lr_scheduler is still importable explicitly by name.
__all__ = ["build_optimizer"]
|
|
|
|
def build_lr_scheduler(lr_config, epochs, step_each_epoch):
    """Build a learning-rate scheduler from a config mapping.

    Args:
        lr_config (dict): scheduler settings. The optional ``name`` key selects
            a class in the sibling ``learning_rate`` module (``"Const"`` when
            omitted); every other key is forwarded to that class's constructor.
            The mapping itself is not modified.
        epochs: forwarded to the scheduler class as ``epochs`` (overrides any
            ``epochs`` key already present in ``lr_config``).
        step_each_epoch: forwarded to the scheduler class as
            ``step_each_epoch`` (overrides any existing key of that name).

    Returns:
        The value produced by calling the constructed scheduler-builder
        instance; ``build_optimizer`` passes it to the optimizer as
        ``learning_rate``. (Presumably a paddle LR scheduler or a float —
        confirm against the ``learning_rate`` module.)

    Raises:
        AttributeError: if ``name`` does not match a class in ``learning_rate``.
    """
    from . import learning_rate

    # Work on a shallow copy: the update/pop below would otherwise mutate the
    # caller's mapping, so reusing the same config for a second call would
    # silently build "Const" because "name" had already been popped.
    lr_config = copy.copy(lr_config)
    lr_config.update({"epochs": epochs, "step_each_epoch": step_each_epoch})
    lr_name = lr_config.pop("name", "Const")
    lr = getattr(learning_rate, lr_name)(**lr_config)()
    return lr
|
|
|
|
def _build_regularizer(config):
    """Pop regularization settings from ``config`` and build the weight decay.

    A non-``None`` ``regularizer`` entry must be a mapping whose ``name`` key
    names a class in the sibling ``regularizer`` module; the suffix ``"Decay"``
    is appended when the bare name is not found there. Its remaining keys go to
    that class's constructor, and the resulting instance is called to produce
    the value returned. Without a regularizer, ``weight_decay`` (or ``None``)
    is returned unchanged.

    Both keys are always removed from ``config``, and the nested
    ``regularizer`` mapping loses its ``name`` key, so ``config`` must be a
    private copy.

    Raises:
        ValueError: if both ``regularizer`` and ``weight_decay`` are set.
        KeyError: if the ``regularizer`` mapping has no ``name`` key.
    """
    from . import regularizer

    reg_config = config.pop("regularizer", None)
    weight_decay = config.pop("weight_decay", None)
    if reg_config is None:
        return weight_decay
    if weight_decay is not None:
        # Left in ``config``, weight_decay would be passed to the optimizer a
        # second time and fail with an opaque "multiple values" TypeError.
        raise ValueError(
            "Optimizer config sets both 'regularizer' and 'weight_decay'; "
            "specify only one of them."
        )
    reg_name = reg_config.pop("name")
    if not hasattr(regularizer, reg_name):
        reg_name += "Decay"
    return getattr(regularizer, reg_name)(**reg_config)()


def _build_grad_clip(config):
    """Pop gradient-clipping settings from ``config`` and build the clipper.

    ``clip_norm`` selects ``paddle.nn.ClipGradByNorm`` and ``clip_norm_global``
    selects ``paddle.nn.ClipGradByGlobalNorm``; a missing or ``None`` value
    means that kind of clipping is not requested. If both are set,
    ``clip_norm`` takes precedence and a warning is emitted. Both keys are
    always removed from ``config`` so neither leaks into the optimizer's kwargs.

    Returns:
        The paddle gradient-clip object, or ``None`` when no clipping is set.
    """
    import warnings

    clip_norm = config.pop("clip_norm", None)
    clip_norm_global = config.pop("clip_norm_global", None)
    if clip_norm is not None:
        if clip_norm_global is not None:
            warnings.warn(
                "Both 'clip_norm' and 'clip_norm_global' are set; using "
                "'clip_norm' and ignoring 'clip_norm_global'.",
                stacklevel=3,
            )
        return paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
    if clip_norm_global is not None:
        return paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm_global)
    return None


def build_optimizer(config, epochs, step_each_epoch, model):
    """Build an optimizer and its learning-rate scheduler from a config mapping.

    ``config`` is deep-copied first, so the caller's mapping is never modified.
    The following keys are consumed:

    * ``lr`` (required): scheduler settings passed to ``build_lr_scheduler``.
    * ``regularizer`` / ``weight_decay`` (optional, mutually exclusive):
      weight decay passed to the optimizer as ``weight_decay``.
    * ``clip_norm`` / ``clip_norm_global`` (optional): gradient clipping passed
      to the optimizer as ``grad_clip``.
    * ``name`` (required): class name in the sibling ``optimizer`` module.

    Every remaining key is forwarded unchanged to that optimizer class.

    Args:
        config (dict): optimizer settings as described above.
        epochs: forwarded to ``build_lr_scheduler``.
        step_each_epoch: forwarded to ``build_lr_scheduler``.
        model: passed to the constructed optimizer-builder instance.

    Returns:
        tuple: ``(optim(model), lr)`` — the result of calling the optimizer
        builder with ``model``, and the scheduler built from ``config["lr"]``.

    Raises:
        KeyError: if ``lr`` or ``name`` is missing.
        ValueError: if both ``regularizer`` and ``weight_decay`` are set.
    """
    from . import optimizer

    config = copy.deepcopy(config)

    # Step 1: learning-rate scheduler.
    lr = build_lr_scheduler(config.pop("lr"), epochs, step_each_epoch)

    # Step 2: regularization and gradient clipping; both helpers pop their keys
    # so that only genuine optimizer arguments remain in ``config``.
    reg = _build_regularizer(config)
    grad_clip = _build_grad_clip(config)

    # Step 3: the optimizer itself.
    optim_name = config.pop("name")
    optim = getattr(optimizer, optim_name)(
        learning_rate=lr, weight_decay=reg, grad_clip=grad_clip, **config
    )
    return optim(model), lr
|
|