GibbsTTS / utils /scheduler.py
ydqmkkx's picture
update
0afe769
raw
history blame contribute delete
947 Bytes
import math
from torch.optim.lr_scheduler import LambdaLR
def cos_scheduler_per_epoch(
    optimizer,
    *,
    epoch: int,
    num_epochs: int,
    steps_in_epoch: int,
    warmup_epochs: float = 0.0,
    num_cycles: float = 0.5,
    min_lr_ratio: float = 0.1,
) -> LambdaLR:
    """Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    The schedule is expressed over the whole run, but the returned scheduler
    only counts steps relative to the start of ``epoch``: the value handed to
    the lambda is treated as the step index inside that epoch.
    NOTE(review): this implies a new scheduler is built at the start of every
    epoch -- confirm against the training loop.

    Args:
        optimizer: Optimizer whose learning rate is scaled.
        epoch: Zero-based index of the epoch this scheduler covers.
        num_epochs: Total number of epochs (values below 1 are treated as 1).
        steps_in_epoch: Scheduler steps per epoch (values below 1 are
            treated as 1).
        warmup_epochs: Length of the linear warmup, in (possibly fractional)
            epochs. ``0`` disables warmup.
        num_cycles: Number of cosine periods spanned by the decay phase;
            the default ``0.5`` is a single half-wave from 1 down to 0.
        min_lr_ratio: Floor of the multiplier; the cosine value is rescaled
            into ``[min_lr_ratio, 1]``.

    Returns:
        A ``LambdaLR`` wrapping ``optimizer``. During warmup the multiplier
        rises linearly from 0 to 1; afterwards it follows the rescaled cosine.
    """
    # Everything below depends only on the arguments, so compute it once
    # instead of on every scheduler step.
    total_epochs = max(1, num_epochs)
    epoch_len = max(1, steps_in_epoch)
    warmup_frac = warmup_epochs / total_epochs
    # Guard against a zero/negative span when warmup covers the whole run.
    decay_span = max(1e-12, 1.0 - warmup_frac)

    def lr_lambda(step_in_epoch: int) -> float:
        # Global progress through training, clamped to [0, 1]. Without the
        # clamp, stepping past the planned end (epoch >= num_epochs or more
        # steps than steps_in_epoch) pushes the cosine past its trough and
        # the learning rate starts climbing again.
        p = (epoch + step_in_epoch / epoch_len) / total_epochs
        p = min(1.0, max(0.0, p))

        # Warmup phase: linear ramp from 0 to 1.
        if warmup_epochs > 0 and p < warmup_frac:
            return p / warmup_frac

        # Cosine decay phase over the remaining fraction of training.
        p_decay = (p - warmup_frac) / decay_span
        cos_val = 0.5 * (1.0 + math.cos(math.pi * 2.0 * num_cycles * p_decay))
        cos_val = max(0.0, cos_val)
        return min_lr_ratio + (1.0 - min_lr_ratio) * cos_val

    return LambdaLR(optimizer, lr_lambda)