| from functools import partial |
| from typing import Callable |
|
|
|
|
def linear_warm_up(
    step: int,
    warm_up_steps: int,
    reduce_lr_steps: int,
    decay_factor: float = 0.9,
) -> float:
    r"""Get linear warm up scheduler for LambdaLR.

    The learning-rate scale rises linearly from 0 at ``step == 0`` to 1 at
    ``step == warm_up_steps``. After warm up, the scale is
    ``decay_factor ** (step // reduce_lr_steps)``. The decay exponent is
    counted from step 0, not from the end of warm up.

    Args:
        step (int): global step
        warm_up_steps (int): steps for warm up; a non-positive value disables
            warm up, so the decay schedule applies from step 0
        reduce_lr_steps (int): period, in steps, between successive
            multiplications of the scale by ``decay_factor``
        decay_factor (float): multiplicative decay applied once every
            ``reduce_lr_steps`` steps after warm up (default 0.9)

    .. code-block:: python

        >>> lr_lambda = partial(linear_warm_up, warm_up_steps=1000, reduce_lr_steps=10000)
        >>> from torch.optim.lr_scheduler import LambdaLR
        >>> LambdaLR(optimizer, lr_lambda)

    Returns:
        lr_scale (float): learning rate scale

    Raises:
        ZeroDivisionError: if ``reduce_lr_steps`` is 0 and ``step`` is past
            warm up.
    """
    # Only ramp when a warm-up window exists; otherwise step 0 would hit 0 / 0.
    if warm_up_steps > 0 and step <= warm_up_steps:
        return step / warm_up_steps

    return decay_factor ** (step // reduce_lr_steps)
|
|
|
|
def constant_warm_up(
    step: int,
    warm_up_steps: int,
    reduce_lr_steps: int,
) -> float:
    r"""Get constant warm up scheduler for LambdaLR.

    The learning-rate scale is held constant within three consecutive warm-up
    stages, each lasting ``warm_up_steps`` steps: 0.001, then 0.01, then 0.1.
    From ``step == 3 * warm_up_steps`` onward the scale is 1.0. Negative steps
    and a non-positive ``warm_up_steps`` also yield 1.0.

    Args:
        step (int): global step
        warm_up_steps (int): length, in steps, of each of the three warm-up
            stages
        reduce_lr_steps (int): unused. It is accepted so this function has
            the same signature as ``linear_warm_up``.

    .. code-block:: python

        >>> lr_lambda = partial(constant_warm_up, warm_up_steps=1000, reduce_lr_steps=10000)
        >>> from torch.optim.lr_scheduler import LambdaLR
        >>> LambdaLR(optimizer, lr_lambda)

    Returns:
        lr_scale (float): learning rate scale
    """
    # Scale for each warm-up stage; stage i covers [i * w, (i + 1) * w).
    stage_scales = (0.001, 0.01, 0.1)

    # With a non-positive warm_up_steps every stage is empty, matching the
    # full-rate fallback below; the guard also avoids dividing by zero.
    if warm_up_steps > 0 and step >= 0:
        stage = step // warm_up_steps
        if stage < len(stage_scales):
            return stage_scales[stage]

    # Return a float (not int 1) to honour the ``-> float`` annotation.
    return 1.0
|
|
|
|
def get_lr_lambda(
    lr_lambda_type: str,
    **kwargs
) -> Callable:
    r"""Get learning rate lambda for ``torch.optim.lr_scheduler.LambdaLR``.

    Args:
        lr_lambda_type (str): schedule name, either ``"constant_warm_up"`` or
            ``"linear_warm_up"``
        **kwargs: must provide ``warm_up_steps`` and ``reduce_lr_steps``,
            which are bound into the returned callable; other keys are ignored

    Returns:
        lr_lambda_func (Callable): function mapping a global step to a
            learning rate scale

    Raises:
        NotImplementedError: if ``lr_lambda_type`` is not a supported schedule
        KeyError: if ``warm_up_steps`` or ``reduce_lr_steps`` is missing from
            ``kwargs``
    """
    if lr_lambda_type == "constant_warm_up":
        schedule = constant_warm_up
    elif lr_lambda_type == "linear_warm_up":
        schedule = linear_warm_up
    else:
        # Name the offending value so configuration typos are easy to spot.
        raise NotImplementedError(
            f"Unsupported lr_lambda_type {lr_lambda_type!r}; expected "
            "'constant_warm_up' or 'linear_warm_up'."
        )

    return partial(
        schedule,
        warm_up_steps=kwargs["warm_up_steps"],
        reduce_lr_steps=kwargs["reduce_lr_steps"],
    )
|
|