| """Learning rate schedulers. |
| |
| Reference: |
| https://raw.githubusercontent.com/huggingface/open-muse/vqgan-finetuning/muse/lr_schedulers.py |
| """ |
| import math |
| from enum import Enum |
| from typing import Optional, Union |
|
|
| import torch |
|
|
|
|
class SchedulerType(Enum):
    """Identifiers for the learning rate schedules defined in this module.

    Each member's value is its lowercase string name, so
    ``SchedulerType("cosine")`` resolves to ``SchedulerType.COSINE``.
    """

    # Linear warm-up, then cosine decay.
    COSINE = "cosine"
    # Linear warm-up, then a flat rate.
    CONSTANT = "constant"
|
|
def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
) -> torch.optim.lr_scheduler.LambdaLR:
    """Creates a cosine learning rate schedule with warm-up and ending learning rate.

    The returned scheduler scales the optimizer's initial learning rate by a
    factor that rises linearly from 0 to 1 over `num_warmup_steps` and then
    follows a cosine curve from 1 down to `end_lr / base_lr`.

    Args:
        optimizer: A torch.optim.Optimizer, the optimizer for which to schedule the learning rate.
        num_warmup_steps: An integer, the number of steps for the warmup phase.
        num_training_steps: An integer, the total number of training steps.
        num_cycles: A float, the number of periods of the cosine function in a schedule (the default
            is to just decrease from `base_lr` to `end_lr` following a half-cosine).
        last_epoch: An integer, the index of the last epoch when resuming training.
        base_lr: A float, the base learning rate. It is only used to convert `end_lr` into a
            multiplicative factor, so it should equal the learning rate configured on `optimizer`
            for the schedule to finish exactly at `end_lr`. Must be non-zero.
        end_lr: A float, the final learning rate.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    Raises:
        ValueError: If `base_lr` is zero, since the decay factor is expressed relative to it.
    """
    # Fail at construction time: otherwise the division below would only blow
    # up with a ZeroDivisionError once the warm-up phase is over, mid-training.
    if base_lr == 0:
        raise ValueError("`base_lr` must be non-zero to build a cosine schedule.")

    # Both denominators are constant for the lifetime of the schedule; the
    # max(1, ...) guards keep them positive when a phase has zero length.
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        # Warm-up phase: linear ramp of the multiplier from 0 towards 1.
        if current_step < num_warmup_steps:
            return float(current_step) / warmup_span
        # Decay phase: fraction of the post-warm-up steps already done.
        progress = float(current_step - num_warmup_steps) / decay_span
        ratio = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
        # Interpolate between end_lr and base_lr, then express the result as a
        # multiplier because LambdaLR scales the optimizer's initial rate.
        return (end_lr + (base_lr - end_lr) * ratio) / base_lr

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
|
|
|
|
def get_constant_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
    last_epoch: int = -1,
) -> torch.optim.lr_scheduler.LambdaLR:
    """UViT: Creates a constant learning rate schedule with warm-up.

    The returned scheduler scales the optimizer's initial learning rate by a
    factor that rises linearly from 0 to 1 over `num_warmup_steps` and stays
    at 1 afterwards.

    Args:
        optimizer: A torch.optim.Optimizer, the optimizer for which to schedule the learning rate.
        num_warmup_steps: An integer, the number of steps for the warmup phase.
        num_training_steps: An integer, the total number of training steps. Unused by this
            schedule; accepted so the signature matches the other schedule builders.
        base_lr: A float, the base learning rate. Unused by this schedule.
        end_lr: A float, the final learning rate. Unused by this schedule.
        last_epoch: An integer, the index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # max(1, ...) keeps the denominator positive when there is no warm-up.
    warmup_span = float(max(1, num_warmup_steps))

    def lr_lambda(current_step):
        # Linear ramp during warm-up, then hold the multiplier at 1.
        if current_step < num_warmup_steps:
            return float(current_step) / warmup_span
        return 1.0

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
|
|
|
|
# Dispatch table from each scheduler type to the function that builds it.
TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule_with_warmup,
}
|
|
def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
):
    """Builds the learning rate scheduler registered under the given name.

    Args:
        name: A string or SchedulerType, identifying which schedule to build.
        optimizer: A torch.optim.Optimizer, the optimizer the scheduler will drive.
        num_warmup_steps: An integer, the number of warmup steps. Required.
        num_training_steps: An integer, the total number of training steps. Required.
        base_lr: A float, the base learning rate, forwarded to the schedule builder.
        end_lr: A float, the final learning rate, forwarded to the schedule builder.

    Returns:
        The scheduler produced by the builder registered for `name` in
        TYPE_TO_SCHEDULER_FUNCTION.

    Raises:
        ValueError: If `name` is not a valid SchedulerType, or if
            num_warmup_steps or num_training_steps is not provided.
    """
    # Normalise to the enum first so an unknown name fails before anything else.
    scheduler_type = SchedulerType(name)
    build_schedule = TYPE_TO_SCHEDULER_FUNCTION[scheduler_type]

    # Both step counts are mandatory; check them in signature order.
    required_args = (
        ("num_warmup_steps", num_warmup_steps),
        ("num_training_steps", num_training_steps),
    )
    for arg_name, arg_value in required_args:
        if arg_value is None:
            raise ValueError(f"{scheduler_type} requires `{arg_name}`, please provide that argument.")

    schedule_kwargs = {
        "num_warmup_steps": num_warmup_steps,
        "num_training_steps": num_training_steps,
        "base_lr": base_lr,
        "end_lr": end_lr,
    }
    return build_schedule(optimizer, **schedule_kwargs)