import os
import shutil
import tempfile

import pytest
import torch
import torch.distributed
import torch.multiprocessing as mp
from torch.distributed import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config

from verl.utils.activation_offload import enable_activation_offloading
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.fsdp_utils import MixedPrecisionPolicy, apply_fsdp2, get_fsdp_wrap_policy


def create_random_input_ids(batch_size, seq_len, vocab_size):
    from flash_attn.bert_padding import unpad_input

    from verl.utils.model import compute_position_id_with_mask, create_random_mask

    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
    attention_mask = create_random_mask(
        input_ids, max_ratio_of_left_padding=0.1, min_ratio_of_valid_token=0.5, max_ratio_of_valid_token=0.7
    )
    position_ids = compute_position_id_with_mask(attention_mask)

    # strip padding tokens and pack the batch into shape (1, total_valid_tokens)
    input_ids = unpad_input(input_ids.unsqueeze(-1), attention_mask)[0].transpose(0, 1)
    position_ids = unpad_input(position_ids.unsqueeze(-1), attention_mask)[0].transpose(0, 1)
    return input_ids, position_ids


def _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, strategy="fsdp"):
    torch.cuda.set_device(rank)
    torch.distributed.init_process_group(
        backend="nccl",
        init_method=f"file://{rendezvous_file}",
        rank=rank,
        world_size=world_size,
    )
    device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",))

    model_name = "Qwen/Qwen2.5-0.5B-Instruct"
    config = Qwen2Config(num_hidden_layers=4)

    with torch.device("cuda"):
        model = AutoModelForCausalLM.from_config(
            config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        )
        model = model.to(device="cuda")
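
    # wrap the model with FSDP (strategy "fsdp") or FSDP2 (strategy "fsdp2")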
    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)

    if strategy == "fsdp":
        model = FSDP(
            model,
            use_orig_params=False,
            device_id=torch.cuda.current_device(),
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            mixed_precision=mixed_precision,
            device_mesh=device_mesh,
            auto_wrap_policy=get_fsdp_wrap_policy(module=model),
        )
    else:
        mp_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True
        )
        fsdp_kwargs = {
            "mesh": device_mesh,
            "mp_policy": mp_policy,
        }
        apply_fsdp2(model, fsdp_kwargs, {})

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
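
    # checkpoint manager for saving/restoring model, optimizer, and scheduler state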
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    checkpoint_manager = FSDPCheckpointManager(
        model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer
    )
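
    # build two batches of random, unpadded inputs; the second batch is reused for the comparison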
    batch_size = 2
    seq_len = 32
    vocab_size = 32000

    input_ids1, position_ids1 = create_random_input_ids(batch_size, seq_len, vocab_size)
    input_ids2, position_ids2 = create_random_input_ids(batch_size, seq_len, vocab_size)
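
    # step 1: one training step without activation offloading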
    outputs1 = model(input_ids=input_ids1, position_ids=position_ids1)
    loss1 = outputs1.logits.mean()
    loss1.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
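
    # snapshot the post-step-1 state so it can be restored later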
    temp_dir = tempfile.mkdtemp()
    checkpoint_path = os.path.join(temp_dir, "checkpoint")
    checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0)
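
    # step 2: a second training step, still without offloading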
    outputs2 = model(input_ids=input_ids2, position_ids=position_ids2)
    loss2 = outputs2.logits.mean()
    loss2.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
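
    # reference logits produced without activation offloading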
    with torch.no_grad():
        logits_without_offloading = model(input_ids=input_ids2, position_ids=position_ids2).logits
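
    # enable activation offloading and roll back to the post-step-1 checkpoint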
    enable_activation_offloading(model, strategy=strategy)
    checkpoint_manager.load_checkpoint(checkpoint_path)
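
    # step 3: replay the second training step, now with offloading enabled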
    outputs3 = model(input_ids=input_ids2, position_ids=position_ids2)
    loss3 = outputs3.logits.mean()
    loss3.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
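
    # logits produced with activation offloading enabled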
    with torch.no_grad():
        logits_with_offloading = model(input_ids=input_ids2, position_ids=position_ids2).logits
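
    # offloading must not change numerics, so require an exact match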
    torch.testing.assert_close(logits_without_offloading, logits_with_offloading, atol=0.0, rtol=0.0)
    print(f"Activation offloading for {strategy} test passed on {world_size} GPUs!")
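
    # clean up the temporary checkpoint directory and tear down the process group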
    shutil.rmtree(temp_dir)
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()


@pytest.mark.parametrize("world_size", (2, 4))
@pytest.mark.parametrize("strategy", ("fsdp", "fsdp2"))
def test_activation_offloading(world_size, strategy, tmp_path):
    rendezvous_file = str(tmp_path / "rdzv_file")
    os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)

    mp.spawn(
        fn=_fsdp_activation_offloading_test,
        args=(world_size, rendezvous_file, strategy),
        nprocs=world_size,
        join=True,
    )