File size: 1,987 Bytes
ba23d94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from sapiens.registry import SCHEDULERS
from torch.optim.lr_scheduler import (
    _LRScheduler,
    ConstantLR,
    CosineAnnealingLR,
    ExponentialLR,
    LinearLR,
    MultiStepLR,
    PolynomialLR,
    SequentialLR as _SequentialLR,
    StepLR,
)

# Expose the stock torch schedulers through the registry under these type
# names, so configs (e.g. ``dict(type="LinearLR", ...)``) can refer to them.
for _sched_name, _sched_cls in (
    ("LinearLR", LinearLR),
    ("PolynomialLR", PolynomialLR),
    ("CosineAnnealingLR", CosineAnnealingLR),
    ("ConstantLR", ConstantLR),
    ("StepLR", StepLR),
    ("MultiStepLR", MultiStepLR),
    ("ExponentialLR", ExponentialLR),
):
    SCHEDULERS.register_module(name=_sched_name)(_sched_cls)
# Keep the loop variables out of the module namespace.
del _sched_name, _sched_cls


# ------------------------------------------------------------------------- #
@SCHEDULERS.register_module(name="SequentialLR")
class SequentialLR(_SequentialLR):
    """SequentialLR that accepts inner schedulers as config dicts.

    Each entry of ``schedulers`` is either an already-constructed scheduler,
    which is used as-is, or a config that is built with
    ``SCHEDULERS.build(cfg, optimizer=optimizer)``. Every built scheduler
    therefore drives the same ``optimizer`` passed here.

    Example (iteration based):

    ```python
    warmup_iters = 400
    param_scheduler = dict(
        type="SequentialLR",
        milestones=[warmup_iters],
        schedulers=[
            dict(type="LinearLR",     start_factor=1e-3,
                 total_iters=warmup_iters),
            dict(type="PolynomialLR", total_iters=num_iters-warmup_iters,
                 power=1.0),
        ],
    )
    ```

    Args:
        optimizer: Optimizer shared by all inner schedulers.
        schedulers: Iterable of scheduler instances and/or scheduler configs.
        milestones: Step indices at which to switch to the next scheduler.
            torch's ``SequentialLR`` expects
            ``len(milestones) == len(schedulers) - 1``.
        last_epoch: Forwarded unchanged to torch's ``SequentialLR``.
            Defaults to ``-1``.
    """

    def __init__(
        self,
        optimizer,
        schedulers,
        milestones,
        last_epoch: int = -1,
    ):
        # Pass built schedulers through; build everything else from config
        # against the outer optimizer.
        built = [
            s
            if self._is_built_scheduler(s)
            else SCHEDULERS.build(s, optimizer=optimizer)
            for s in schedulers
        ]
        super().__init__(
            optimizer,
            schedulers=built,
            milestones=milestones,
            last_epoch=last_epoch,
        )

    @staticmethod
    def _is_built_scheduler(obj) -> bool:
        """Return True if *obj* is an already-constructed LR scheduler."""
        # torch >= 2.0: concrete schedulers (LinearLR, PolynomialLR, ...)
        # derive from the public ``LRScheduler`` base, and ``_LRScheduler``
        # is only a legacy subclass of it. An ``isinstance(obj, _LRScheduler)``
        # check would therefore reject a pre-built LinearLR.
        # torch < 2.0 has no ``LRScheduler``, so fall back to ``_LRScheduler``.
        from torch.optim import lr_scheduler

        base = getattr(lr_scheduler, "LRScheduler", _LRScheduler)
        return isinstance(obj, base)