import torch


class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Implements the learning rate schedule defined in the AlphaFold 2
    supplement: a linear warmup is followed by a plateau at the maximum
    learning rate, then exponential decay.

    Note that the initial learning rate of the optimizer in question is
    ignored; use this class's base_lr parameter to specify the starting
    point of the warmup.
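
    A minimal usage sketch (illustrative; the one-parameter optimizer
    below stands in for a real model's parameters):

        >>> param = torch.nn.Parameter(torch.zeros(1))
        >>> optimizer = torch.optim.Adam([param], lr=1.)
        >>> scheduler = AlphaFoldLRScheduler(optimizer, max_lr=0.001)
        >>> for _ in range(2000):
        ...     optimizer.step()
        ...     scheduler.step()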
| """ |
    def __init__(self,
        optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        base_lr: float = 0.,
        max_lr: float = 0.001,
        warmup_no_steps: int = 1000,
        start_decay_after_n_steps: int = 50000,
        decay_every_n_steps: int = 50000,
        decay_factor: float = 0.95,
    ):
        step_counts = {
            "warmup_no_steps": warmup_no_steps,
            "start_decay_after_n_steps": start_decay_after_n_steps,
        }

        for k, v in step_counts.items():
            if v < 0:
                raise ValueError(f"{k} must be nonnegative")

        if warmup_no_steps > start_decay_after_n_steps:
            raise ValueError(
                "warmup_no_steps must not exceed start_decay_after_n_steps"
            )

        # get_lr() floor-divides by decay_every_n_steps, so it must be
        # strictly positive
        if decay_every_n_steps <= 0:
            raise ValueError("decay_every_n_steps must be positive")

        self.optimizer = optimizer
        self.last_epoch = last_epoch
        self.verbose = verbose
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.warmup_no_steps = warmup_no_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.decay_every_n_steps = decay_every_n_steps
        self.decay_factor = decay_factor

        super().__init__(
            optimizer,
            last_epoch=last_epoch,
            verbose=verbose,
        )

    def state_dict(self):
        # As in the base class, omit the optimizer itself from the
        # serialized state
        return {
            k: v for k, v in self.__dict__.items() if k != "optimizer"
        }

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, "
                "use get_last_lr()"
            )

        # The base class stores the current step count in last_epoch
        step_no = self.last_epoch

        if step_no < self.warmup_no_steps:
            # Linear warmup from base_lr to max_lr. The strict inequality
            # also avoids a division by zero when warmup_no_steps == 0
            frac = step_no / self.warmup_no_steps
            lr = self.base_lr + frac * (self.max_lr - self.base_lr)
        elif step_no > self.start_decay_after_n_steps:
            # Exponential decay: multiply by decay_factor once per
            # decay_every_n_steps steps after the plateau ends
            steps_since_decay = step_no - self.start_decay_after_n_steps
            exp = (steps_since_decay // self.decay_every_n_steps) + 1
            lr = self.max_lr * (self.decay_factor ** exp)
        else:
            # Plateau at max_lr between warmup and decay
            lr = self.max_lr

        return [lr for _ in self.optimizer.param_groups]


class TestAF2LRScheduler(AlphaFoldLRScheduler):
    """An AlphaFoldLRScheduler with much shorter phases, suitable for
    quick tests.
    """
    def __init__(self,
        optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        base_lr: float = 0.,
        max_lr: float = 0.0001,
        warmup_no_steps: int = 10,
        start_decay_after_n_steps: int = 100,
        decay_every_n_steps: int = 10,
        decay_factor: float = 0.95,
    ):
        super().__init__(
            optimizer,
            last_epoch=last_epoch,
            verbose=verbose,
            base_lr=base_lr,
            max_lr=max_lr,
            warmup_no_steps=warmup_no_steps,
            start_decay_after_n_steps=start_decay_after_n_steps,
            decay_every_n_steps=decay_every_n_steps,
            decay_factor=decay_factor,
        )
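

# A minimal smoke test sketch: drives TestAF2LRScheduler through all three
# phases (warmup, plateau, decay) and prints the learning rate at a few
# milestone steps. The single-parameter optimizer below is a placeholder,
# not a real model.
if __name__ == "__main__":
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.Adam([param], lr=1.)  # initial lr is ignored
    scheduler = TestAF2LRScheduler(optimizer)

    for step in range(1, 201):
        optimizer.step()
        scheduler.step()
        if step in (1, 5, 10, 100, 110, 200):
            print(f"step {step:3d}: lr = {scheduler.get_last_lr()[0]:.3e}")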