# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from sapiens.registry import SCHEDULERS
from torch.optim.lr_scheduler import (
    _LRScheduler,
    ConstantLR,
    CosineAnnealingLR,
    ExponentialLR,
    LinearLR,
    MultiStepLR,
    PolynomialLR,
    SequentialLR as _SequentialLR,
    StepLR,
)

SCHEDULERS.register_module(name="LinearLR")(LinearLR)
SCHEDULERS.register_module(name="PolynomialLR")(PolynomialLR)
SCHEDULERS.register_module(name="CosineAnnealingLR")(CosineAnnealingLR)
SCHEDULERS.register_module(name="ConstantLR")(ConstantLR)
SCHEDULERS.register_module(name="StepLR")(StepLR)
SCHEDULERS.register_module(name="MultiStepLR")(MultiStepLR)
SCHEDULERS.register_module(name="ExponentialLR")(ExponentialLR)
# ------------------------------------------------------------------------- #
@SCHEDULERS.register_module(name="SequentialLR")
class SequentialLR(_SequentialLR):
"""SequentialLR that accepts inner schedulers as config dicts.
Example (iteration based):
```python
warmup_iters = 400
param_scheduler = dict(
type="SequentialLR",
milestones=[warmup_iters],
schedulers=[
dict(type="LinearLR", start_factor=1e-3,
total_iters=warmup_iters),
dict(type="PolynomialLR", total_iters=num_iters-warmup_iters,
power=1.0),
],
)
```
"""
    def __init__(
        self,
        optimizer,
        schedulers,
        milestones,
        last_epoch: int = -1,
    ):
        # Pass through schedulers that are already constructed; build config
        # dicts via the registry, injecting the shared optimizer.
        built = [
            s
            if isinstance(s, _LRScheduler)
            else SCHEDULERS.build(s, optimizer=optimizer)
            for s in schedulers
        ]
        super().__init__(
            optimizer,
            schedulers=built,
            milestones=milestones,
            last_epoch=last_epoch,
        )
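# ------------------------------------------------------------------------- #
# Usage sketch: builds the registered SequentialLR directly from config
# dicts. The model, optimizer, and iteration count are illustrative, not
# part of this module.
#
#   import torch
#
#   model = torch.nn.Linear(4, 2)
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   scheduler = SequentialLR(
#       optimizer,
#       schedulers=[
#           dict(type="LinearLR", start_factor=1e-3, total_iters=400),
#           dict(type="ConstantLR", factor=1.0, total_iters=600),
#       ],
#       milestones=[400],
#   )
#   for _ in range(1000):
#       optimizer.step()
#       scheduler.step()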