import torch
from lightning.pytorch.loggers import TensorBoardLogger
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig

from nemo import lightning as nl
from nemo.collections import llm, vlm

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.vlm.gemma3vl.data.mock import Gemma3VLMockDataModule
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
from nemo.utils.exp_manager import TimingCallback

# Recipe name; used below to build the default log directory.
NAME = "gemma3vl_4b"

# Hugging Face checkpoint that supplies the tokenizer and pretrained weights.
HF_MODEL_NAME = "google/gemma-3-4b-it"


def finetune_recipe(
    log_dir: str,
    num_nodes: int = 1,
    num_gpus_per_node: int = 1,
):
    """Finetune the Gemma 3 VL 4B model on mock multimodal data."""
    max_sequence_length = 512

    # Model: Gemma 3 language backbone, vision encoder, and multimodal projector.
    # Only the language model is trained; the vision tower and projector stay frozen.
    tokenizer = AutoTokenizer(HF_MODEL_NAME)
    language_transformer_config = llm.Gemma3Config4B(seq_length=max_sequence_length)
    vision_transformer_config = vlm.Gemma3VLVisionConfig()
    vision_projection_config = vlm.Gemma3VLMultimodalProjectorConfig(
        input_size=vision_transformer_config.hidden_size,
        hidden_size=language_transformer_config.hidden_size,
    )
    gemma3vl_config = vlm.Gemma3VLConfig(
        language_transformer_config=language_transformer_config,
        vision_transformer_config=vision_transformer_config,
        vision_projection_config=vision_projection_config,
        freeze_language_model=False,
        freeze_vision_model=True,
        freeze_vision_projection=True,
    )
    model = vlm.Gemma3VLModel(gemma3vl_config, tokenizer=tokenizer)
    # Convert the Hugging Face checkpoint into NeMo format before finetuning.
    llm.import_ckpt(model=model, source=f"hf://{HF_MODEL_NAME}")

    # Parallelism: tensor parallelism of 2 (so the world size, num_nodes * num_gpus_per_node,
    # must be a multiple of 2), sequence parallelism enabled, bf16 pipeline dtype.
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=1,
        pipeline_dtype=torch.bfloat16,
        virtual_pipeline_model_parallel_size=None,
        encoder_tensor_model_parallel_size=0,
        encoder_pipeline_model_parallel_size=0,
        context_parallel_size=1,
        sequence_parallel=True,
        ckpt_async_save=True,
        ckpt_parallel_save=True,
        ckpt_parallel_load=True,
        ckpt_parallel_save_optim=True,
        ckpt_load_strictness="log_all",
        gradient_as_bucket_view=True,
        ddp=DistributedDataParallelConfig(
            check_for_nan_in_grad=True,
            grad_reduce_in_fp32=True,
            overlap_grad_reduce=True,
            overlap_param_gather=True,
            average_in_collective=True,
        ),
    )

    # Trainer: a 5-step run in bf16 mixed precision. Lightning's built-in checkpointing
    # is disabled; NeMo's ModelCheckpoint (configured below) is used instead.
    trainer = nl.Trainer(
        accelerator="gpu",
        devices=num_gpus_per_node,
        num_nodes=num_nodes,
        max_steps=5,
        limit_val_batches=0,
        val_check_interval=10,
        num_sanity_val_steps=0,
        log_every_n_steps=1,
        strategy=strategy,
        accumulate_grad_batches=1,
        use_distributed_sampler=False,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
        enable_checkpointing=False,
        callbacks=[TimingCallback()],
    )

    # Optimizer: distributed Adam in bf16 with gradient clipping and cosine LR decay.
    opt_config = OptimizerConfig(
        optimizer="adam",
        lr=3e-4,
        weight_decay=0.1,
        bf16=True,
        fp16=False,
        adam_beta1=0.9,
        adam_beta2=0.95,
        adam_eps=1e-5,
        use_distributed_optimizer=True,
        clip_grad=1.0,
    )
    lr_scheduler = CosineAnnealingScheduler(
        warmup_steps=2000,
        constant_steps=0,
        min_lr=3e-5,
    )
    opt = MegatronOptimizerModule(config=opt_config, lr_scheduler=lr_scheduler)

    # Data: mock image-text batches; swap in a real data module for actual finetuning.
    data = Gemma3VLMockDataModule(
        seq_length=max_sequence_length,
        global_batch_size=4,
        micro_batch_size=1,
        tokenizer=tokenizer,
        num_workers=4,
    )

    # Logging and checkpointing: keep the best and the last checkpoint, and write
    # TensorBoard summaries under the experiment log directory.
    ckpt = nl.ModelCheckpoint(
        save_top_k=1,
        save_last=True,
        save_optim_on_train_end=False,
        filename="{val_loss:.2f}-{step}-{consumed_samples}",
    )
    tb = TensorBoardLogger(
        save_dir="tensorboard",
        name="",
    )
    logger = nl.NeMoLogger(
        explicit_log_dir=log_dir,
        log_global_rank_0_only=True,
        update_logger_directory=True,
        ckpt=ckpt,
        tensorboard=tb,
    )

    # Resume from an existing checkpoint in log_dir, if one is present.
    resume = nl.AutoResume(
        resume_if_exists=True,
        resume_ignore_no_checkpoint=True,
    )

    llm.finetune(
        model=model,
        data=data,
        trainer=trainer,
        log=logger,
        optim=opt,
        resume=resume,
    )


if __name__ == "__main__":
    finetune_recipe(f"/tmp/{NAME}", num_gpus_per_node=2)
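
# Launch sketch (an assumption, not part of the original recipe): with tensor parallelism
# set to 2, the default __main__ invocation above expects 2 GPUs, so one common way to
# run the script is with a distributed launcher such as torchrun, for example:
#
#   torchrun --nproc-per-node=2 <this_script>.py
#
# Adjust the process count and launcher to match your environment.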