File size: 947 Bytes
0afe769
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import math
from torch.optim.lr_scheduler import LambdaLR


def cos_scheduler_per_epoch(
    optimizer,
    *,
    epoch: int,
    num_epochs: int,
    steps_in_epoch: int,
    warmup_epochs: float = 0.0,
    num_cycles: float = 0.5,   
    min_lr_ratio: float = 0.1,   
):

    def lr_lambda(step_in_epoch: int):
        # global progress in [0, 1]
        p = (epoch + step_in_epoch / max(1, steps_in_epoch)) / max(1, num_epochs)

        # warmup phase
        if warmup_epochs > 0:
            p_w = warmup_epochs / max(1, num_epochs)
            if p < p_w:
                return p / p_w

        # cosine decay phase
        p_w = warmup_epochs / max(1, num_epochs)
        p_decay = (p - p_w) / max(1e-12, 1.0 - p_w)

        cos_val = 0.5 * (1.0 + math.cos(math.pi * 2.0 * num_cycles * p_decay))
        cos_val = max(0.0, cos_val)

        return min_lr_ratio + (1.0 - min_lr_ratio) * cos_val

    return LambdaLR(optimizer, lr_lambda)