import os
import socket
import time
import traceback
from functools import partial

import torch
import torch.distributed as dist

import internlm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.scheduler import SchedulerMetricHook
from internlm.core.trainer import TrainState
from internlm.initialize import initialize_distributed_env
from internlm.model.loss import FlashGPTLMLoss, KLDivLoss
from internlm.model.metrics import AccPerplex
from internlm.monitor import initialize_monitor_manager, send_alert_message
from internlm.monitor.monitor import monitor_manager as mm
from internlm.train import (
    get_train_data_loader,
    get_validation_data_loader,
    initialize_llm_profile,
    initialize_model,
    initialize_optimizer,
    initialize_teacher,
    load_new_batch,
    load_new_batch_stop,
    record_current_batch_training_metrics,
)
from internlm.utils.common import (
    BatchSkipper,
    get_megatron_flops,
    launch_time,
    parse_args,
)
from internlm.utils.evaluation import evaluate_on_val_dls
from internlm.utils.gputest import empty_cache_and_diag
from internlm.utils.logger import get_logger, initialize_uniscale_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import CheckpointManager, load_model_checkpoint
from internlm.utils.parallel import get_parallel_log_file_name
from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler
from internlm.utils.writer import Writer

# global llm logger
logger = get_logger(__file__)


def initialize_llm_logger(start_time: str):
    """
    Initialize customized uniscale logger.

    Args:
        start_time (str): The launch time of the current training job.

    Returns: The instance of uniscale logger.
    """
    uniscale_logger = initialize_uniscale_logger(
        job_name=gpc.config.JOB_NAME, launch_time=start_time, file_name=get_parallel_log_file_name()
    )
    if uniscale_logger is not None:
        global logger
        logger = uniscale_logger

    return uniscale_logger


def main(args):
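    # init setting: read training-control options from the config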
    skip_batches = gpc.config.data.skip_batches
    total_steps = gpc.config.data.total_steps
    valid_every = gpc.config.data.valid_every
    label_smoothing = gpc.config.loss.label_smoothing
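
    # partial function that computes Megatron-style TFLOPs for this model configuration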
    get_tflops_func = partial(
        get_megatron_flops,
        checkpoint=gpc.config.model.checkpoint,
        seq_len=gpc.config.SEQ_LEN,
        hidden_size=gpc.config.model.hidden_size,
        num_layers=gpc.config.model.num_layers,
        vocab_size=gpc.config.model.vocab_size,
        global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA),
        global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
        mlp_ratio=gpc.config.MLP_RATIO,
    )
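
    # get the launch time on rank 0 and broadcast it so that all ranks share one timestamp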
    current_time = launch_time()
    objs = [current_time]
    dist.broadcast_object_list(objs, src=0)
    current_time = objs[0]
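
    # initialize customized llm logger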
    uniscale_logger = initialize_llm_logger(start_time=current_time)
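
    # initialize model and report its trainable parameter count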
    model = initialize_model()
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.warning(f"Model parameters: {n_parameters / 1e6} M.")

    with open(args.config, "r") as f:
        config_lines = f.readlines()
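
    # initialize loss function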
    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
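
    # initialize the train and validation data loaders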
    train_dl, dataset_types = get_train_data_loader(num_worker=4)
    val_dls = get_validation_data_loader()
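
    # initialize and resume train state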
    train_state = TrainState(gpc.config, train_dl.batch_sampler)

    optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)

    ckpt_manager = CheckpointManager(
        ckpt_config=gpc.config.ckpt,
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_dl=train_dl,
        model_config=gpc.config.model,
        model_config_file="".join(config_lines),
        feishu_address=gpc.config.monitor.alert.feishu_alert_address,
    )
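
    # resume training states from the latest checkpoint, if any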
    ckpt_manager.try_resume_training(train_state, current_time)
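
    # initialize customized llm writer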
    writer = Writer(
        job_name=gpc.config.JOB_NAME,
        launch_time=current_time,
        file_name=get_parallel_log_file_name(),
        tensorboard_folder=gpc.config.tensorboard_folder,
        resume_tb_folder=train_state.resume_tb_folder,
        step_count=train_state.step_count,
        config=config_lines,
        logger=logger,
        enable_tb=gpc.config.enable_tb,
    )
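
    # initialize metric for calculating accuracy and perplexity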
    metric = AccPerplex(
        device=torch.cuda.current_device(),
        tp_pg=gpc.get_group(ParallelMode.TENSOR),
        dp_pg=gpc.get_group(ParallelMode.DATA),
        dataset_types=dataset_types,
    )
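
    # initialize trainer and train process scheduler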
    scheduler_hooks = [
        SchedulerMetricHook(
            metric=metric,
            skip=(
                gpc.is_using_pp()
                and hasattr(gpc.config.model, "num_chunks")
                and gpc.config.model.num_chunks > 1
                and gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
            ),
        ),
    ]

    if gpc.config.get("kd_config", None) is None:
        trainer, train_dl, _, _ = internlm.initialize_trainer(
            model=model,
            optimizer=optimizer,
            criterion=criterion,
            train_dataloader=train_dl,
            lr_scheduler=lr_scheduler,
            beta2_scheduler=beta2_scheduler,
            scheduler_hooks=scheduler_hooks,
        )
    else:
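        # initialize the teacher model and freeze it for knowledge distillation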
        teacher = initialize_teacher()
        n_parameters = sum(p.numel() for p in teacher.parameters())
        logger.warning(f"Teacher parameters: {n_parameters / 1e6} M.")

        teacher.requires_grad_(False)
        teacher.eval()

        kd_criterion_type = gpc.config.kd_config.get("type", "kl_div")
        if kd_criterion_type == "kl_div":
            kd_criterion = KLDivLoss()
        else:
            raise NotImplementedError(f"Unsupported kd criterion type: {kd_criterion_type}")

        trainer, train_dl, _, _ = internlm.initialize_kd_trainer(
            model=model,
            teacher=teacher,
            optimizer=optimizer,
            criterion=criterion,
            kd_criterion=kd_criterion,
            train_dataloader=train_dl,
            lr_scheduler=lr_scheduler,
            beta2_scheduler=beta2_scheduler,
            scheduler_hooks=scheduler_hooks,
        )
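
    # initialize simple memory profiler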
    if args.profiling:
        memory_profiler = SimpleMemoryProfiler(
            model,
            optimizer.optim,
            log_folder=f"memory_trace/rank{gpc.get_global_rank()}_"
            + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
            + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}",
        )
    else:
        memory_profiler = None
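
    # initialize the batch skipper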
    batch_skipper = BatchSkipper(skip_batches)

    trainer.train()
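
    # transfer the train data loader into a train data iterator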
    train_iter = iter(train_dl)

    with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
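        # start iterating the train data and begin training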
        for batch_count in range(train_state.batch_count, total_steps):
            empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
            start_time = time.time()
            timer("one-batch").start()
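
            # load a new batch; in train-one-epoch mode, checkpoint and stop once the data iterator is exhausted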
            if gpc.config.data.train_one_epoch:
                batch, train_iter = load_new_batch_stop(
                    train_dl=train_dl, train_iter=train_iter, train_state=train_state
                )
                if batch is None:
                    ckpt_manager.try_save_checkpoint(train_state, data_iter_stop=True)
                    break
            else:
                batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
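
            # record the consumed samples in training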
            train_state.batch_count = batch_count
            train_state.num_consumed_samples_in_epoch += len(batch[1])
            if batch_skipper(batch_count):  # skip this batch
                if gpc.is_rank_for_log():
                    logger.info(f"Skip batch count:`{batch_count}`...")
                timer("one-batch").stop()
                continue
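
            # zero the grads of parameters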
            trainer.zero_grad()
            # process data: pass the per-token type ids to the metric before they are popped from the batch
            if batch[0].get("type_ids", None) is not None:
                metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))
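
            # do forward and backward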
| timer("fwd-bwd").start() |
|
|
| _, _, loss = trainer.execute_schedule( |
| batch, forward_only=False, return_loss=True, return_output_label=False |
| ) |
| timer("fwd-bwd").stop() |
|
|
| |
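
            # update parameters; trainer.step() returns (success_update, grad_norm_groups)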
            trainer_result = trainer.step()
            assert trainer_result is not None

            success_update, grad_norm_groups = trainer_result
            if success_update:  # update parameters successfully
                train_state.step_count += 1
            else:
                train_state.inf_nan_skip_batches += 1  # record the number of unsuccessful parameter updates
                # a grad norm of -1 marks a skipped update (e.g. inf/nan gradients)
                if -1 in grad_norm_groups.values() and gpc.is_rank_for_log():
                    logger.warning(f"Warning: skip parameter update at step {batch_count}.")
                    send_alert_message(
                        address=gpc.config.monitor.alert.feishu_alert_address,
                        message=f"Warning: skip parameter update at step {batch_count}.",
                    )
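
            # calculate and record the training metrics, e.g. loss, accuracy, and so on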
            record_current_batch_training_metrics(
                get_tflops_func=get_tflops_func,
                logger=logger,
                writer=writer,
                success_update=success_update,
                batch_count=batch_count,
                batch=batch,
                train_state=train_state,
                optimizer=optimizer,
                beta2_scheduler=beta2_scheduler,
                trainer=trainer,
                start_time=start_time,
                loss=loss,
                grad_norm=grad_norm_groups,
                metric=metric,
                update_panel=uniscale_logger is not None,
            )

            timer("one-batch").stop()
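
            # evaluate on validation data loaders every `valid_every` steps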
            if valid_every > 0 and train_state.step_count % valid_every == 0:
                evaluate_on_val_dls(
                    trainer=trainer,
                    val_dls=val_dls,
                    writer=writer,
                    logger=logger,
                    step_count=train_state.step_count,
                    update_panel=uniscale_logger is not None,
                )
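
            # checkpoint the training states at the steps configured in ckpt_config; stop training if requested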
            now_break = ckpt_manager.try_save_checkpoint(train_state)
            if now_break:
                break

            if memory_profiler is not None:
                memory_profiler.step()

            if batch_count % 2 == 0:
                prof.step()

    ckpt_manager.wait_async_upload_finish()


if __name__ == "__main__":
    args = parse_args()
    hostname = socket.gethostname()
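
    # initialize distributed environment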
    initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed)
    assert hasattr(gpc, "config") and gpc.config is not None
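
    # initialize monitor manager (feishu alerting) and run the training entry point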
    with initialize_monitor_manager(
        job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address
    ):
        try:
            main(args)
        except Exception:
            logger.error(
                f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
            )
            mm.monitor_exception(
                alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc()
            )