# Custom optimizer factories: each takes (args, model, dataset) and returns an
# (optimizer, lr_scheduler) tuple, where lr_scheduler may be None.
import math
import os
import sys

from transformers import Trainer

from swift.trainers.optimizers.galore import create_optimizer_and_scheduler
from swift.utils import get_dist_setting
|
|
def calculate_max_steps(args: 'TrainArguments', dataset) -> int:
    """Return the total number of optimizer update steps for the run.

    Uses args.max_steps when it is set to a positive value; otherwise derives the
    step count from the dataset length, the effective batch size across all ranks
    (per-device batch size * gradient accumulation * world size), and the number
    of training epochs.
    """
    if args.max_steps and args.max_steps > 0:
        max_steps = args.max_steps
    else:
        len_dataset = len(dataset)
        _, _, world_size, _ = get_dist_setting()
        total_train_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * world_size
        num_update_steps_per_epoch = len_dataset // total_train_batch_size
        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
        max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
    return max_steps
|
|
def create_galore_optimizer(args, model, dataset):
    training_steps = calculate_max_steps(args, dataset)
    optimizer, lr_scheduler = create_optimizer_and_scheduler(
        model, args, args.galore_config, training_steps, lr=args.learning_rate, weight_decay=args.weight_decay)
    # galore_config has already been passed to the optimizer; clear the reference afterwards.
    args.galore_config = None
    return optimizer, lr_scheduler
|
|
def create_lorap_optimizer(args, model, dataset):
    optimizer_grouped_parameters = None
    if hasattr(model, 'create_optimizer_param_groups'):
        # LoRA+ style grouping: the model supplies its own parameter groups
        # (e.g. a higher learning rate for the LoRA B matrices).
        optimizer_grouped_parameters = model.create_optimizer_param_groups(
            lr=args.learning_rate, weight_decay=args.weight_decay)

    if optimizer_grouped_parameters is None:
        # Fall back to the default transformers grouping: apply weight decay to all
        # trainable parameters except those the Trainer excludes (biases, norm layers).
        decay_parameters = Trainer.get_decay_parameter_names(None, model)
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in model.named_parameters() if (n in decay_parameters and p.requires_grad)],
                'weight_decay': args.weight_decay,
            },
            {
                'params': [p for n, p in model.named_parameters() if (n not in decay_parameters and p.requires_grad)],
                'weight_decay': 0.0,
            },
        ]
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
    return optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs), None
|
|
def create_muon_optimizer(args, model, dataset):
    from swift.llm import git_clone_github, get_model_arch

    # The Muon implementation lives in the Moonlight repository; clone it if no local copy was given.
    if not args.local_repo_path:
        args.local_repo_path = git_clone_github('https://github.com/MoonshotAI/Moonlight.git')
    sys.path.append(os.path.join(args.local_repo_path, 'examples'))
    from toy_train import Muon

    # Parse extra optimizer kwargs passed as '--optim_args key1=value1,key2=value2' (values stay strings).
    optim_args = {}
    if args.optim_args:
        for mapping in args.optim_args.replace(' ', '').split(','):
            key, value = mapping.split('=')
            optim_args[key] = value

    # Muon updates parameters with ndim >= 2, excluding the embedding and lm_head weights;
    # everything else is handled by the optimizer's internal AdamW.
    model_arch = get_model_arch(model.model_meta.model_arch)
    embed_key = model_arch.embedding or 'embed_tokens'
    lm_head_key = model_arch.lm_head or 'lm_head'
    muon_params = [
        p for n, p in model.named_parameters()
        if p.requires_grad and p.ndim >= 2 and embed_key not in n and lm_head_key not in n
    ]
    adamw_params = [
        p for n, p in model.named_parameters()
        if p.requires_grad and not (p.ndim >= 2 and embed_key not in n and lm_head_key not in n)
    ]

    return Muon(
        lr=args.learning_rate,
        wd=args.weight_decay,
        muon_params=muon_params,
        adamw_params=adamw_params,
        adamw_betas=(args.adam_beta1, args.adam_beta2),
        adamw_eps=args.adam_epsilon,
        **optim_args,
    ), None
|
|
optimizers_map = {
    'galore': create_galore_optimizer,
    'lorap': create_lorap_optimizer,
    'muon': create_muon_optimizer,
}
|
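# A minimal sketch of how this registry can presumably be extended (the factory name
# below is hypothetical, not part of the original module): any callable with the same
# (args, model, dataset) -> (optimizer, lr_scheduler) signature can be added to the map.
#
#     def create_my_optimizer(args, model, dataset):
#         optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
#         params = [p for p in model.parameters() if p.requires_grad]
#         return optimizer_cls(params, **optimizer_kwargs), None
#
#     optimizers_map['my_optim'] = create_my_optimizer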
|