# scripts/performance/llm/mlperf_lora_llama2_70b.py
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import nemo_run as run
import torch
from megatron.core.optimizer import OptimizerConfig
from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.common.tokenizers.huggingface import AutoTokenizer
from nemo.collections.llm.gpt.data import MLPerfGovReportDataModule
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs
from nemo.collections.llm.gpt.model.llama import Llama2Config70B, LlamaModel, MLPerfLoRALlamaModel
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.lightning.run.plugins import MemoryProfilePlugin, NsysPlugin, PerfEnvPlugin
from ..argument_parser import parse_cli_args
from ..executors import slurm_executor
from ..helpers import args_sanity_check
from ..utils import import_ckpt_experiment
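
# Default launch configuration: one 8x H100 node with tp4/pp1/cp1, matching the
# MLPerf submission this recipe is based on; each value below can be overridden
# on the command line (see the __main__ block).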
NUM_NODES = 1
NUM_GPUS_PER_NODE = 8
MICRO_BATCH_SIZE = 1
GLOBAL_BATCH_SIZE = 8
TP_SIZE = 4  # tp_comm_overlap_cfg may need re-tuning if this value changes
PP_SIZE = 1
CP_SIZE = 1
MAX_STEPS = 100
SEQ_LENGTH = 8192
HF_MODEL_URI = "meta-llama/Llama-2-70b-hf"
SKIP_IMPORT = False


def mlperf_lora_llama2_70b_recipe(
num_nodes: int,
num_gpus_per_node: int,
mbs: int,
gbs: int,
tp_size: int,
pp_size: int,
cp_size: int,
max_steps: int,
):
"""
    A recipe for Llama 2 70B LoRA fine-tuning based on the performance optimizations
    used in NVIDIA's MLPerf Training submissions. Some optimizations are experimental;
    use with caution.
    Last updated for the v4.1 submission. Designed for 8x H100 with tp4/pp1/cp1.
"""
tokenizer = run.Config(
AutoTokenizer,
pretrained_model_name=HF_MODEL_URI,
)
packed_sequence_specs = run.Config(
PackedSequenceSpecs,
packed_sequence_size=SEQ_LENGTH,
)
datamodule = run.Config(
MLPerfGovReportDataModule,
seq_length=SEQ_LENGTH,
tokenizer=tokenizer,
micro_batch_size=mbs,
global_batch_size=gbs,
persistent_workers=True,
packed_sequence_specs=packed_sequence_specs,
dataset_kwargs={
"return_cu_seqlen": False,
},
)
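    # Optimizer: Megatron distributed optimizer in bf16; the LR, grad-clip, and
    # weight-decay values follow the MLPerf submission this recipe is based on.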
optimizer_config = run.Config(
OptimizerConfig,
lr=2e-4,
clip_grad=0.3,
weight_decay=0.0001,
bf16=True,
params_dtype=torch.bfloat16,
use_distributed_optimizer=True,
)
scheduler = run.Config(
CosineAnnealingScheduler,
max_steps=max_steps,
warmup_steps=500,
constant_steps=0,
min_lr=0,
)
optimizer = run.Config(
nl.MegatronOptimizerModule,
config=optimizer_config,
lr_scheduler=scheduler,
)
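    # LoRA adapters on the attention projections only (linear_qkv and linear_proj),
    # rank 16 / alpha 32, with the experimental all-to-all (a2a) communication
    # path enabled for tensor-parallel adapters.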
peft = run.Config(
llm.peft.LoRA,
dim=16,
alpha=32,
dropout=0.1,
a2a_experimental=True,
dropout_position="pre",
lora_A_init_method="kaiming",
target_modules=['linear_proj', 'linear_qkv'],
)
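    # Model: Llama 2 70B with MLPerf performance knobs: FP8 dot-product attention
    # enabled, the parameter-transpose cache disabled to save memory, and CUDA
    # graphs left off.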
llama2_config = run.Config(
llm.Llama2Config70B,
seq_length=SEQ_LENGTH,
tp_comm_overlap_disable_qkv=True,
fp8_dot_product_attention=1,
cross_entropy_loss_fusion=False,
disable_parameter_transpose_cache=True,
external_cuda_graph=False,
enable_cuda_graph=False,
)
llama2_config.microbatch_group_size_per_vp_stage = 1
model = run.Config(
MLPerfLoRALlamaModel,
llama2_config,
)
resume = llm.recipes.finetune_default.nemo_resume(HF_MODEL_URI)
from megatron.core.distributed import DistributedDataParallelConfig
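    # Parallelism: tensor parallelism (with sequence parallelism whenever tp_size > 1);
    # checkpoint loading uses "log_all" strictness so mismatches such as the missing
    # LoRA adapter weights in the base checkpoint are logged rather than fatal.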
strategy = run.Config(
nl.MegatronStrategy,
tensor_model_parallel_size=tp_size,
pipeline_model_parallel_size=pp_size,
context_parallel_size=cp_size,
        sequence_parallel=(tp_size > 1),  # sequence parallelism requires TP > 1
pipeline_dtype=torch.bfloat16,
ckpt_load_directly_on_device=False,
ckpt_parallel_load=False,
ckpt_load_strictness="log_all",
gradient_as_bucket_view=True,
use_te_rng_tracker=True,
ddp=run.Config(
DistributedDataParallelConfig,
use_distributed_optimizer=True,
),
)
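    # Precision: bf16 autocast plus the FP8 "hybrid" recipe (E4M3 forward, E5M2
    # backward) with delayed scaling over a 32-step amax history.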
precision = run.Config(
nl.MegatronMixedPrecision,
precision="bf16-mixed",
params_dtype=torch.bfloat16,
pipeline_dtype=torch.bfloat16,
autocast_enabled=True,
grad_reduce_in_fp32=False,
fp8="hybrid",
fp8_amax_history_len=32,
fp8_amax_compute_algo='max',
fp8_param_gather=True,
fp8_dot_product_attention=1,
)
# Communication overlap settings
tp_comm_overlap_cfg = None
ub_tp_comm_overlap = tp_size > 1
if ub_tp_comm_overlap:
from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import (
BulkOverlapCfg,
PipelineOverlapCfg,
RingExchangeOverlapCfg,
TransformerLayerTPOverlapCfg,
)
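        # Per-GEMM userbuffer overlap methods, tuned for tp4 on H100: ring exchange
        # for most GEMMs, pipelined overlap for the two fprop projections, and bulk
        # overlap for the qkv dgrad.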
tp_comm_overlap_cfg = TransformerLayerTPOverlapCfg(
qkv_fprop=RingExchangeOverlapCfg(),
fc1_fprop=RingExchangeOverlapCfg(),
proj_dgrad=RingExchangeOverlapCfg(),
fc2_dgrad=RingExchangeOverlapCfg(),
proj_fprop=PipelineOverlapCfg(
num_sm=32, cga_size=2, num_splits=4, set_sm_margin=True, atomic_gemm=True, fp8_buf=True
),
fc2_fprop=PipelineOverlapCfg(
num_sm=16, cga_size=2, num_splits=4, set_sm_margin=True, atomic_gemm=True, fp8_buf=False
),
qkv_dgrad=BulkOverlapCfg(num_sm=4, cga_size=2, set_sm_margin=False),
fc1_dgrad=RingExchangeOverlapCfg(cga_size=2, set_sm_margin=True),
qkv_wgrad=None,
fc1_wgrad=None,
)
overlap_callback = run.Config(
MegatronCommOverlapCallback,
tp_comm_overlap=ub_tp_comm_overlap,
tp_comm_overlap_cfg=tp_comm_overlap_cfg,
overlap_grad_reduce=False,
overlap_param_gather=False,
overlap_param_gather_with_optimizer_step=False,
)
    # Effectively disable automatic garbage collection during the timed run by
    # scheduling the first collection past max_steps.
    from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
    gc_callback = run.Config(
        GarbageCollectionCallback,
        gc_interval_train=max_steps + 1,
        gc_interval_val=max_steps + 1,
    )
# Log step times
from nemo.utils.exp_manager import TimingCallback
timing_callback = run.Config(TimingCallback)
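    # Trainer: validation is effectively disabled (limit_val_batches=0,
    # num_sanity_val_steps=0), so the run measures training throughput only.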
trainer = run.Config(
nl.Trainer,
max_steps=max_steps,
val_check_interval=(max_steps - 1),
limit_val_batches=0,
devices=num_gpus_per_node,
num_nodes=num_nodes,
accelerator="gpu",
strategy=strategy,
plugins=precision,
num_sanity_val_steps=0,
accumulate_grad_batches=1,
enable_checkpointing=False,
enable_model_summary=False,
enable_progress_bar=True,
use_distributed_sampler=False,
log_every_n_steps=10,
callbacks=[overlap_callback, gc_callback, timing_callback],
)
recipe = run.Partial(
llm.finetune,
model=model,
trainer=trainer,
data=datamodule,
optim=optimizer,
resume=resume,
peft=peft,
)
recipe.model.tokenizer = tokenizer
return recipe
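

# Example launch, assuming the flag spellings match the attribute names read off
# the parsed args below (see parse_cli_args for the authoritative set) and that
# the script is run as a module so the relative imports resolve:
#
#   python -m scripts.performance.llm.mlperf_lora_llama2_70b \
#       --gpu h100 --account <slurm_account> --partition <slurm_partition> \
#       --log_dir /path/to/logs --hf_token <hf_token> \
#       --finetuning lora --compute_dtype fp8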
if __name__ == "__main__":
args = parse_cli_args().parse_args()
args_sanity_check(args)
if args.finetuning.lower() != "lora" or args.compute_dtype != "fp8":
raise ValueError(
"This example assumes LoRA and fp8; instead got "
f"--finetuning={args.finetuning} --compute_dtype={args.compute_dtype}"
)
if args.virtual_pipeline_parallel_size or args.expert_parallel_size:
raise ValueError(
"This example does not support virtual pipeline parallel or expert parallel; "
f"got --virtual_pipeline_parallel_size={args.virtual_pipeline_parallel_size} "
f"--expert_parallel_size={args.expert_parallel_size}"
)
num_gpus_per_node = NUM_GPUS_PER_NODE if args.gpus_per_node is None else args.gpus_per_node
num_gpus = NUM_NODES * num_gpus_per_node if args.num_gpus is None else args.num_gpus
    num_nodes = -(num_gpus // -num_gpus_per_node)  # ceil(num_gpus / num_gpus_per_node)
mbs = MICRO_BATCH_SIZE if args.micro_batch_size is None else args.micro_batch_size
gbs = GLOBAL_BATCH_SIZE if args.global_batch_size is None else args.global_batch_size
tp_size = TP_SIZE if args.tensor_parallel_size is None else args.tensor_parallel_size
pp_size = PP_SIZE if args.pipeline_parallel_size is None else args.pipeline_parallel_size
cp_size = CP_SIZE if args.context_parallel_size is None else args.context_parallel_size
exp_name = "_".join(
[
args.finetuning.lower(),
"llama2_70b",
args.compute_dtype,
f"{num_nodes}nodes",
f"tp{tp_size}_pp{pp_size}_cp{cp_size}",
f"{mbs}mbs_{gbs}gbs",
]
)
env_vars = {
# pylint: disable=C0301
"CUDA_DEVICE_MAX_CONNECTIONS": (
"32" if args.gpu.lower() in ['b200', 'gb200'] else "1"
), # Limit GPUs to one compute stream so that kernels will be executed in consistent order, for best performance with communication overlap configs
# pylint: disable=C0301
"CUBLAS_FORCE_XMMA_KERNEL_INIT": "DEVICE", # Use a device kernel instead of memset for matrix multiply initialization, which can help reduce CPU-side overhead
"CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT": "0", # Reduce memory used by cuDNN attention
"NVTE_FP8_DPA_BWD": "1", # Enable TransformerEngine FP8 attention for bprop
# pylint: disable=C0301
"NVTE_RS_STRIDED_ATOMIC": "2", # Reduce-scatter communication will be done as a single kernel (with userbuffer TP communication overlap)
"NCCL_MIN_CTAS": "32", # Increase CTA resources available to NCCL to improve communication perf
"NCCL_MIN_P2P_NCHANNELS": "32", # Likewise, increase P2P channels availabe to NCCL
"NCCL_NCHANNELS_PER_NET_PEER": "32", # Likewise, increase per-peer P2P channel limit
"NCCL_NVLS_ENABLE": "0", # Disable NVLink SHARP to save memory
"TORCH_NCCL_AVOID_RECORD_STREAMS": "1", # Disable caching NCCL communication buffer memory
}
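    # Build the Slurm executor; env_vars above are forwarded to the launched job
    # through custom_env_vars.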
executor = slurm_executor(
args.gpu.lower(),
args.account,
args.partition,
args.log_dir,
num_nodes,
num_gpus_per_node,
args.time_limit,
args.container_image,
custom_mounts=args.custom_mounts,
custom_env_vars=env_vars,
hf_token=args.hf_token,
nemo_home=args.nemo_home,
wandb_key=args.wandb_key,
network='sharp' if args.use_sharp else None,
)
recipe = mlperf_lora_llama2_70b_recipe(
num_nodes,
num_gpus_per_node,
mbs,
gbs,
tp_size,
pp_size,
cp_size,
args.max_steps,
)
if not args.tensorboard:
recipe.log.tensorboard = None
recipe.trainer.logger = False
else:
recipe.log.log_dir = "/nemo_run/lightning_logs"
if args.wandb:
assert args.wandb_prj_name is not None
assert args.wandb_job_name is not None
from nemo.collections.llm.recipes.log.default import wandb_logger
recipe.log.wandb = wandb_logger(project=args.wandb_prj_name, name=args.wandb_job_name)
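    # Performance plugins: PerfEnvPlugin applies the environment tuning (including
    # optional vboost clock boosting); Nsys profiling and memory profiling are opt-in.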
plugins = [
PerfEnvPlugin(
enable_vboost=True,
            nccl_pp_comm_chunksize=2097152 if pp_size > 1 else None,
gpu_sm100_or_newer=(args.gpu.lower() in ['b200', 'gb200']),
)
]
if args.enable_nsys:
plugins.append(NsysPlugin(start_step=5, end_step=6))
if args.enable_memory_profile:
assert args.memory_profile_out_path is not None
plugins.append(MemoryProfilePlugin(dir=args.memory_profile_out_path))
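    # Two sequential jobs: optionally import the HF checkpoint into NeMo format,
    # then run the LoRA fine-tuning recipe on the same executor.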
with run.Experiment(exp_name) as exp:
if not SKIP_IMPORT:
assert args.hf_token is not None, "HF token is required for importing checkpoint from HuggingFace"
exp.add(
*import_ckpt_experiment(
executor, run.Config(LlamaModel, config=run.Config(Llama2Config70B)), source=f"hf://{HF_MODEL_URI}"
)
)
exp.add(
recipe,
executor=executor,
name=exp_name,
plugins=plugins,
)
if not args.dryrun:
exp.run(sequential=True, detach=True)
else:
exp.dryrun()