# VibeToken/utils/lr_schedulers.py
# Uploaded by APGASU (commit 7bef20f, verified)
"""Learning rate schedulers.
Reference:
https://raw.githubusercontent.com/huggingface/open-muse/vqgan-finetuning/muse/lr_schedulers.py
"""
import math
from enum import Enum
from typing import Optional, Union
import torch
class SchedulerType(Enum):
    """Names of the supported learning-rate schedule families."""

    COSINE = "cosine"
    CONSTANT = "constant"
def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
):
    """Build a cosine-decay LR schedule with linear warm-up and a final LR floor.

    The LR multiplier ramps linearly from 0 to 1 over ``num_warmup_steps``,
    then follows a cosine curve so the effective learning rate moves from
    ``base_lr`` down to ``end_lr`` over the remaining training steps.

    Args:
        optimizer: A torch.optim.Optimizer whose learning rate is scheduled.
        num_warmup_steps: An integer, the number of linear warm-up steps.
        num_training_steps: An integer, the total number of training steps.
        num_cycles: A float, the number of cosine periods in the schedule
            (the default 0.5 decays from the max value to the floor along a
            half-cosine).
        last_epoch: An integer, the index of the last epoch when resuming.
        base_lr: A float, the peak learning rate (the optimizer's base LR).
        end_lr: A float, the learning rate reached at the end of training.
    Return:
        A `torch.optim.lr_scheduler.LambdaLR` implementing the schedule.
    """

    def multiplier(step):
        # Warm-up phase: scale linearly from 0 up to the base learning rate.
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        # Fraction of the post-warmup schedule completed, in [0, 1].
        frac = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        # Cosine decay factor, clamped at 0 so extra steps never go negative.
        decay = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * frac)))
        # Interpolate between end_lr and base_lr, expressed relative to base_lr
        # because LambdaLR multiplies the optimizer's base LR by this value.
        return (end_lr + (base_lr - end_lr) * decay) / base_lr

    return torch.optim.lr_scheduler.LambdaLR(optimizer, multiplier, last_epoch)
def get_constant_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
    last_epoch: int = -1,
):
    """UViT: Creates a constant learning rate schedule with linear warm-up.

    The LR multiplier ramps linearly from 0 to 1 over ``num_warmup_steps`` and
    then stays at 1 (i.e. the optimizer's configured base learning rate) for
    the rest of training. There is no decay.

    Args:
        optimizer: A torch.optim.Optimizer, the optimizer for which to schedule the learning rate.
        num_warmup_steps: An integer, the number of steps for the warmup phase.
        num_training_steps: Unused; accepted only so this factory shares the
            call signature of `get_cosine_schedule_with_warmup` (see
            `get_scheduler`, which forwards it to every scheduler type).
        base_lr: Unused; accepted for signature compatibility.
        end_lr: Unused; accepted for signature compatibility.
        last_epoch: An integer, the index of the last epoch when resuming
            training (default -1 starts a fresh schedule, matching the
            previous behavior).
    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        # Linear warm-up; max(1, ...) guards against num_warmup_steps == 0.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return 1.0

    # Forward last_epoch so resuming behaves like the cosine counterpart.
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
# Dispatch table used by `get_scheduler` to map a SchedulerType onto the
# factory function that builds the corresponding LambdaLR schedule.
TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule_with_warmup,
}
def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    base_lr: float = 1e-4,
    end_lr: float = 0.0,
):
    """Build a learning-rate scheduler by name for the given optimizer.

    Args:
        name: A string or SchedulerType identifying which schedule to build.
        optimizer: torch.optim.Optimizer the schedule will drive.
        num_warmup_steps: An integer, the number of warm-up steps (required).
        num_training_steps: An integer, the total training steps (required).
        base_lr: A float, the base learning rate.
        end_lr: A float, the final learning rate.
    Returns:
        An instance of torch.optim.lr_scheduler.LambdaLR.
    Raises:
        ValueError: If `name` is not a known scheduler, or if
            `num_warmup_steps` / `num_training_steps` is not provided.
    """
    # Normalize to the enum first; an unknown name raises ValueError here.
    name = SchedulerType(name)
    # Every supported schedule needs both step counts, so validate up front.
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
    build = TYPE_TO_SCHEDULER_FUNCTION[name]
    return build(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        base_lr=base_lr,
        end_lr=end_lr,
    )