import torch
import numpy as np


def create_optimizer(model, config):
    """Build an optimizer with layer-wise learning rate decay over the encoder blocks."""
    train_config = config['training']
    base_lr = train_config['learning_rate']
    weight_decay = train_config['weight_decay']

    # Per-layer decay factor: earlier layers get progressively smaller learning rates.
    layer_decay = train_config.get('layer_decay', 0.8)

    # Transformer blocks plus one extra "layer" for the patch embedding.
    num_layers = len(model.blocks) + 1

    parameter_groups = []

    # Classification head: trained at its own learning rate (defaults to base_lr).
    head_lr = train_config.get('head_lr', base_lr)
    parameter_groups.append({
        "params": [p for n, p in model.named_parameters() if "head" in n],
        "lr": head_lr,
        "weight_decay": weight_decay
    })

    # Transformer blocks: block i is scaled by layer_decay ** (num_layers - i - 1),
    # so later (deeper) blocks stay closer to base_lr.
    for i, block in enumerate(model.blocks):
        scale = layer_decay ** (num_layers - i - 1)
        parameter_groups.append({
            "params": block.parameters(),
            "lr": base_lr * scale,
            "weight_decay": weight_decay
        })

    # Earliest parameters (patch embedding and encoder norm) get the smallest learning rate.
    earliest_params = []
    for n, p in model.named_parameters():
        if "patch_embed" in n or "encoder_norm" in n:
            earliest_params.append(p)

    if earliest_params:
        parameter_groups.append({
            "params": earliest_params,
            "lr": base_lr * (layer_decay ** num_layers),
            "weight_decay": weight_decay
        })
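
    # For reference (hypothetical numbers): with a 12-block encoder and layer_decay=0.8,
    # the groups above would receive roughly:
    #   head                        -> head_lr (defaults to base_lr)
    #   last block  (i = 11)        -> base_lr * 0.8         (0.8 ** 1)
    #   first block (i = 0)         -> base_lr * 0.8 ** 12   (~0.069 * base_lr)
    #   patch_embed / encoder_norm  -> base_lr * 0.8 ** 13   (~0.055 * base_lr)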

    if train_config['optimizer'].lower() == 'adamw':
        optimizer = torch.optim.AdamW(
            parameter_groups,
            betas=tuple(train_config['betas']),
            # Per-group weight_decay values set above take precedence over this default.
            weight_decay=weight_decay
        )
    elif train_config['optimizer'].lower() == 'sgd':
        optimizer = torch.optim.SGD(
            parameter_groups,
            momentum=train_config.get('momentum', 0.9),
            weight_decay=weight_decay
        )
    else:
        raise ValueError(f"Unsupported optimizer: {train_config['optimizer']}")

    return optimizer


def create_lr_scheduler(optimizer, config, steps_per_epoch):
    """Create a learning rate scheduler with linear warmup followed by cosine decay."""
    train_config = config['training']
    total_steps = train_config['epochs'] * steps_per_epoch
    warmup_steps = train_config['warmup_epochs'] * steps_per_epoch

    if train_config['lr_scheduler'].lower() == 'cosine':
        def lr_lambda(current_step):
            if current_step < warmup_steps:
                # Linear warmup from 0 to each group's base learning rate.
                return float(current_step) / float(max(1, warmup_steps))
            else:
                # Cosine decay after warmup, floored at min_lr expressed as a
                # fraction of the base learning rate.
                progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
                return max(train_config['min_lr'] / train_config['learning_rate'],
                           0.5 * (1.0 + np.cos(np.pi * progress)))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    else:
        raise ValueError(f"Unsupported scheduler: {train_config['lr_scheduler']}")

    return scheduler
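

# Minimal usage sketch (illustrative only): `_ToyViT` and the config values below are
# assumptions, not part of this module. The real model only needs a `blocks` container
# plus parameter names containing "head", "patch_embed", or "encoder_norm" for the
# grouping in create_optimizer to cover every parameter.
if __name__ == "__main__":
    import torch.nn as nn

    class _ToyViT(nn.Module):
        def __init__(self, dim=32, depth=4, num_classes=10):
            super().__init__()
            self.patch_embed = nn.Linear(dim, dim)
            self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])
            self.encoder_norm = nn.LayerNorm(dim)
            self.head = nn.Linear(dim, num_classes)

    config = {
        'training': {
            'learning_rate': 1e-3,
            'min_lr': 1e-6,
            'weight_decay': 0.05,
            'layer_decay': 0.8,
            'optimizer': 'adamw',
            'betas': [0.9, 0.999],
            'lr_scheduler': 'cosine',
            'epochs': 10,
            'warmup_epochs': 1,
        }
    }

    model = _ToyViT()
    optimizer = create_optimizer(model, config)

    # Each group carries its own layer-decayed learning rate.
    for group in optimizer.param_groups:
        print(f"lr = {group['lr']:.2e}")

    # Step once per optimizer update in the real training loop.
    scheduler = create_lr_scheduler(optimizer, config, steps_per_epoch=100)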