| |
|
|
| import math |
| from enum import Enum |
| from functools import partial |
|
|
| from torch.optim import Optimizer |
| from torch.optim.lr_scheduler import LambdaLR |
|
|
|
|
class SchedulerType(Enum):
    """Names of the learning-rate schedules this module knows about.

    Each member's value is the string form accepted by ``get_scheduler``,
    e.g. ``SchedulerType("warmup_stable_decay")``. Only the members that have
    an entry in ``SCHEDULER_REGISTRY`` can actually be built; ``get_scheduler``
    raises ``ValueError`` for the others.
    """

    # Cosine-shaped schedules.
    COSINE = "cosine"
    COSINE_WITH_WARMUP = "cosine_with_warmup"
    COSINE_WITH_RESTARTS = "cosine_with_restarts"
    # Three-phase warmup / constant / decay schedule.
    WARMUP_STABLE_DECAY = "warmup_stable_decay"
    # Further warmup variants.
    POLYNOMIAL_WITH_WARMUP = "polynomial_with_warmup"
    LINEAR_WITH_WARMUP = "linear_with_warmup"
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
    INVERSE_SQRT = "inverse_sqrt"
|
|
|
|
| def _get_warmup_stable_decay_lr_lambda( |
| current_step: int, |
| *, |
| num_warmup_steps: int, |
| num_stable_steps: int, |
| num_training_steps: int, |
| min_lr_ratio: float = 0.001, |
| decay_type: str = "cosine", |
| ): |
| """ |
| Learning rate lambda function for Warmup-Stable-Decay (WSD) schedule. |
| |
| This scheduler implements three phases: |
| 1. Warmup: Linear increase from 0 to peak learning rate |
| 2. Stable: Constant learning rate for majority of training |
| 3. Decay: Gradual decrease using cosine or linear decay |
| |
| Args: |
| current_step: Current training step |
| num_warmup_steps: Number of warmup steps |
| num_stable_steps: Number of stable learning rate steps |
| num_training_steps: Total number of training steps |
| min_lr_ratio: Minimum learning rate as ratio of peak learning rate |
| decay_type: Type of decay schedule ("cosine" or "linear") |
| """ |
| if current_step < num_warmup_steps: |
| |
| return float(current_step) / float(max(1, num_warmup_steps)) |
|
|
| elif current_step < num_warmup_steps + num_stable_steps: |
| |
| return 1.0 |
|
|
| else: |
| |
| decay_steps = num_training_steps - num_warmup_steps - num_stable_steps |
| if decay_steps <= 0: |
| return max(min_lr_ratio, 1.0) |
|
|
| progress = (current_step - num_warmup_steps - num_stable_steps) / decay_steps |
| progress = min(progress, 1.0) |
|
|
| if decay_type == "cosine": |
| |
| decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress)) |
| return max(min_lr_ratio, decay_factor) |
| elif decay_type == "linear": |
| |
| decay_factor = 1.0 - progress |
| return max(min_lr_ratio, decay_factor) |
| else: |
| raise ValueError(f"Unknown decay_type: {decay_type}") |
|
|
|
|
def get_warmup_stable_decay_schedule(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_stable_steps: int,
    num_training_steps: int,
    min_lr_ratio: float = 0.01,
    decay_type: str = "cosine",
    last_epoch: int = -1,
):
    """
    Create a Warmup-Stable-Decay learning rate schedule.

    The learning rate ramps up linearly for ``num_warmup_steps`` steps. It
    then stays at the peak for ``num_stable_steps`` steps and finally decays
    (cosine or linear) over the remaining steps up to
    ``num_training_steps``, never dropping below ``min_lr_ratio`` times the
    peak. A long stable phase keeps the learning rate high for most of
    training. Note that this factory still needs the total length up front,
    because it places the decay phase using ``num_training_steps``.

    Args:
        optimizer: The optimizer for which to schedule the learning rate.
        num_warmup_steps: Number of steps for the warmup phase.
        num_stable_steps: Number of steps for the constant-learning-rate phase.
        num_training_steps: Total number of training steps.
        min_lr_ratio: Minimum learning rate as a fraction of the peak rate.
        decay_type: Type of decay, "cosine" or "linear".
        last_epoch: The index of the last epoch when resuming training.

    Returns:
        torch.optim.lr_scheduler.LambdaLR with the WSD schedule.

    Raises:
        ValueError: If ``decay_type`` is not "cosine" or "linear".
    """
    # Validate here so a typo is reported when the scheduler is built, not
    # only once training reaches the decay phase.
    if decay_type not in ("cosine", "linear"):
        raise ValueError(f"Unknown decay_type: {decay_type}")

    lr_lambda = partial(
        _get_warmup_stable_decay_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_stable_steps=num_stable_steps,
        num_training_steps=num_training_steps,
        min_lr_ratio=min_lr_ratio,
        decay_type=decay_type,
    )
    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
|
|
|
|
| def _get_cosine_schedule_with_warmup_lr_lambda( |
| current_step: int, |
| *, |
| num_warmup_steps: int, |
| num_training_steps: int, |
| num_cycles: float = 0.5, |
| min_lr_ratio: float = 0.0, |
| ): |
| """Enhanced cosine schedule with configurable minimum learning rate.""" |
| if current_step < num_warmup_steps: |
| return float(current_step) / float(max(1, num_warmup_steps)) |
|
|
| progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) |
| cosine_factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) |
| return max(min_lr_ratio, cosine_factor) |
|
|
|
|
def get_enhanced_cosine_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    min_lr_ratio: float = 0.01,
    last_epoch: int = -1,
):
    """
    Build a LambdaLR that warms up linearly and then follows a cosine curve.

    After ``num_warmup_steps`` the multiplier follows
    ``0.5 * (1 + cos(2 * pi * num_cycles * progress))``, never dropping below
    ``min_lr_ratio``.

    Args:
        optimizer: The optimizer whose learning rate is scheduled.
        num_warmup_steps: Length of the linear warmup, in steps.
        num_training_steps: Total number of training steps.
        num_cycles: Number of cosine cycles (0.5 = a single half cosine).
        min_lr_ratio: Smallest allowed fraction of the peak learning rate.
        last_epoch: Index of the last epoch, for resuming training.

    Returns:
        torch.optim.lr_scheduler.LambdaLR driven by
        ``_get_cosine_schedule_with_warmup_lr_lambda``.
    """
    schedule_kwargs = {
        "num_warmup_steps": num_warmup_steps,
        "num_training_steps": num_training_steps,
        "num_cycles": num_cycles,
        "min_lr_ratio": min_lr_ratio,
    }
    # Bind the hyper-parameters with partial onto the module-level lambda.
    multiplier = partial(_get_cosine_schedule_with_warmup_lr_lambda, **schedule_kwargs)
    return LambdaLR(optimizer, lr_lambda=multiplier, last_epoch=last_epoch)
|
|
|
|
| def _get_cosine_with_restarts_lr_lambda( |
| current_step: int, |
| *, |
| num_warmup_steps: int, |
| num_training_steps: int, |
| num_cycles: int = 1, |
| min_lr_ratio: float = 0.0, |
| ): |
| """Cosine schedule with hard restarts and configurable minimum learning rate.""" |
| if current_step < num_warmup_steps: |
| return float(current_step) / float(max(1, num_warmup_steps)) |
|
|
| progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) |
| if progress >= 1.0: |
| return min_lr_ratio |
|
|
| cosine_factor = 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))) |
| return max(min_lr_ratio, cosine_factor) |
|
|
|
|
def get_cosine_with_restarts_schedule(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: int = 4,
    min_lr_ratio: float = 0.01,
    last_epoch: int = -1,
):
    """
    Build a LambdaLR with linear warmup followed by cosine hard restarts.

    The post-warmup span is split into ``num_cycles`` equal cycles. Each cycle
    jumps back to the peak learning rate and follows a half cosine down,
    floored at ``min_lr_ratio`` times the peak.

    Args:
        optimizer: The optimizer whose learning rate is scheduled.
        num_warmup_steps: Length of the linear warmup, in steps.
        num_training_steps: Total number of training steps.
        num_cycles: How many restart cycles follow the warmup.
        min_lr_ratio: Smallest allowed fraction of the peak learning rate.
        last_epoch: Index of the last epoch, for resuming training.

    Returns:
        torch.optim.lr_scheduler.LambdaLR driven by
        ``_get_cosine_with_restarts_lr_lambda``.
    """
    restart_settings = {
        "num_warmup_steps": num_warmup_steps,
        "num_training_steps": num_training_steps,
        "num_cycles": num_cycles,
        "min_lr_ratio": min_lr_ratio,
    }
    multiplier = partial(_get_cosine_with_restarts_lr_lambda, **restart_settings)
    return LambdaLR(optimizer, lr_lambda=multiplier, last_epoch=last_epoch)
|
|
|
|
| |
# Maps each buildable SchedulerType to its factory. get_scheduler calls the
# factory as factory(optimizer=..., num_warmup_steps=..., num_training_steps=...,
# **scheduler_kwargs). SchedulerType members missing from this table are
# rejected by get_scheduler with ValueError.
SCHEDULER_REGISTRY = {
    SchedulerType.WARMUP_STABLE_DECAY: get_warmup_stable_decay_schedule,  # needs num_stable_steps in kwargs
    SchedulerType.COSINE_WITH_WARMUP: get_enhanced_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_restarts_schedule,
}
|
|
|
|
def get_scheduler(
    scheduler_type: str | SchedulerType,
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    scheduler_kwargs: dict | None = None,
):
    """
    Unified interface to create learning rate schedulers.

    Args:
        scheduler_type: Type of scheduler to create, as a SchedulerType or its
            string value (e.g. "warmup_stable_decay").
        optimizer: The optimizer to schedule.
        num_warmup_steps: Number of warmup steps.
        num_training_steps: Total training steps.
        scheduler_kwargs: Additional scheduler-specific keyword arguments,
            forwarded to the factory (e.g. ``num_stable_steps`` for
            warmup_stable_decay). Not modified.

    Returns:
        Configured learning rate scheduler.

    Raises:
        ValueError: If ``scheduler_type`` is not a SchedulerType value, or
            names a type with no factory in SCHEDULER_REGISTRY. The message
            lists the supported values.
        TypeError: If ``scheduler_kwargs`` repeats an argument that is
            already passed explicitly.
    """
    supported = ", ".join(member.value for member in SCHEDULER_REGISTRY)

    if isinstance(scheduler_type, str):
        try:
            scheduler_type = SchedulerType(scheduler_type)
        except ValueError as err:
            raise ValueError(
                f"Unknown scheduler type: {scheduler_type!r}. Supported types: {supported}"
            ) from err

    if scheduler_type not in SCHEDULER_REGISTRY:
        raise ValueError(f"Unsupported scheduler type: {scheduler_type}. Supported types: {supported}")

    extra_kwargs = scheduler_kwargs if scheduler_kwargs is not None else {}
    # These are passed explicitly below; repeating them in scheduler_kwargs
    # would otherwise surface as an opaque "got multiple values" TypeError.
    clashing = {"optimizer", "num_warmup_steps", "num_training_steps"} & set(extra_kwargs)
    if clashing:
        raise TypeError(
            f"scheduler_kwargs must not contain {sorted(clashing)}; "
            "pass them as arguments to get_scheduler instead"
        )

    scheduler_func = SCHEDULER_REGISTRY[scheduler_type]
    return scheduler_func(
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        **extra_kwargs,
    )
|
|
|
|
| class WarmupStableDecayScheduler: |
| """ |
| Alternative implementation as a standalone scheduler class. |
| |
| This provides more flexibility and better state management for |
| complex training scenarios with checkpointing. |
| """ |
|
|
| def __init__( |
| self, |
| optimizer: Optimizer, |
| num_warmup_steps: int, |
| num_stable_steps: int, |
| total_steps: int, |
| min_lr_ratio: float = 0.01, |
| decay_type: str = "cosine", |
| verbose: bool = False, |
| ): |
| self.optimizer = optimizer |
| self.num_warmup_steps = num_warmup_steps |
| self.num_stable_steps = num_stable_steps |
| self.total_steps = total_steps |
| self.min_lr_ratio = min_lr_ratio |
| self.decay_type = decay_type |
| self.verbose = verbose |
|
|
| |
| self.base_lrs = [group["lr"] for group in optimizer.param_groups] |
| self.current_step = 0 |
|
|
| def get_lr_factor(self, step: int) -> float: |
| """Calculate the learning rate multiplication factor for given step.""" |
| if step < self.num_warmup_steps: |
| |
| return step / max(1, self.num_warmup_steps) |
| elif step < self.num_warmup_steps + self.num_stable_steps: |
| |
| return 1.0 |
| else: |
| |
| decay_steps = self.total_steps - self.num_warmup_steps - self.num_stable_steps |
| if decay_steps <= 0: |
| return max(self.min_lr_ratio, 1.0) |
|
|
| progress = (step - self.num_warmup_steps - self.num_stable_steps) / decay_steps |
| progress = min(progress, 1.0) |
|
|
| if self.decay_type == "cosine": |
| decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress)) |
| elif self.decay_type == "linear": |
| decay_factor = 1.0 - progress |
| else: |
| raise ValueError(f"Unknown decay_type: {self.decay_type}") |
|
|
| return max(self.min_lr_ratio, decay_factor) |
|
|
| def step(self): |
| """Update learning rates for all parameter groups.""" |
| lr_factor = self.get_lr_factor(self.current_step) |
|
|
| for param_group, base_lr in zip(self.optimizer.param_groups, self.base_lrs, strict=True): |
| param_group["lr"] = base_lr * lr_factor |
|
|
| if self.verbose and self.current_step % 1000 == 0: |
| phase = self.get_phase() |
| print(f"Step {self.current_step}: LR factor = {lr_factor:.6f}, Phase = {phase}") |
|
|
| self.current_step += 1 |
|
|
| def get_phase(self) -> str: |
| """Get current training phase.""" |
| if self.current_step < self.num_warmup_steps: |
| return "warmup" |
| elif self.current_step < self.num_warmup_steps + self.num_stable_steps: |
| return "stable" |
| else: |
| return "decay" |
|
|
| def state_dict(self) -> dict: |
| """Return scheduler state for checkpointing.""" |
| return { |
| "current_step": self.current_step, |
| "base_lrs": self.base_lrs, |
| } |
|
|
| def load_state_dict(self, state_dict: dict): |
| """Load scheduler state from checkpoint.""" |
| self.current_step = state_dict["current_step"] |
| self.base_lrs = state_dict["base_lrs"] |
|
|