| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
"""
Usage example:
    python scripts/vlm/gemma3_automodel.py --model llava-hf/llava-v1.6-mistral-7b-hf --data_path naver-clova-ix/cord-v2

If omitted, --model defaults to google/gemma-3-4b-it and --data_path defaults to quintend/rdr-items.

Uses Automodel to train and fine-tune Hugging Face models. More details can be found at:
https://docs.nvidia.com/nemo-framework/user-guide/latest/automodel/index.html
"""
|
|
|
|
| import fiddle as fdl |
| import lightning.pytorch as pl |
| import torch |
| from lightning.pytorch.loggers import WandbLogger |
|
|
| from nemo import lightning as nl |
| from nemo.automodel.dist_utils import FirstRankPerNode |
| from nemo.collections import llm, vlm |
| from nemo.collections.vlm.hf.data.automodel_datasets import ( |
| mk_hf_vlm_dataset_cord_v2, |
| mk_hf_vlm_dataset_fineweb_edu, |
| mk_hf_vlm_dataset_rdr, |
| ) |
|
|
|
|
def make_strategy(strategy, model, devices, num_nodes, adapter_only=False):
    """Build the Lightning training strategy selected on the command line.

    Args:
        strategy: One of ``'auto'`` (single device), ``'ddp'`` or ``'fsdp2'``.
        model: Model wrapper whose ``make_checkpoint_io(adapter_only=...)``
            supplies the checkpoint I/O plugin handed to the strategy.
        devices: Devices per node; used to size FSDP2 data parallelism.
        num_nodes: Number of nodes; used to size FSDP2 data parallelism.
        adapter_only: Forwarded to ``model.make_checkpoint_io``.

    Returns:
        The configured strategy instance.

    Raises:
        NotImplementedError: If ``strategy`` is not one of the supported names.
    """
    supported = ('auto', 'ddp', 'fsdp2')
    # Validate before touching the model so a typo fails fast with an actionable message.
    if strategy not in supported:
        raise NotImplementedError(f"Encountered unknown strategy {strategy!r}; expected one of {supported}")

    checkpoint_io = model.make_checkpoint_io(adapter_only=adapter_only)

    if strategy == 'auto':
        # Single-process training pinned to the first CUDA device.
        return pl.strategies.SingleDeviceStrategy(
            device='cuda:0',
            checkpoint_io=checkpoint_io,
        )
    if strategy == 'ddp':
        # NOTE(review): presumably enabled because some parameters (e.g. the vision
        # tower on text-only batches) may not receive gradients every step — confirm.
        return pl.strategies.DDPStrategy(
            checkpoint_io=checkpoint_io,
            find_unused_parameters=True,
        )
    # 'fsdp2': shard data-parallel over every device in the job, no tensor parallelism.
    return nl.FSDP2Strategy(
        data_parallel_size=devices * num_nodes,
        tensor_parallel_size=1,
        checkpoint_io=checkpoint_io,
    )
|
|
|
|
| if __name__ == '__main__': |
| import argparse |
|
|
| parser = argparse.ArgumentParser() |
| parser.add_argument('--name', default='gemma3_automodel') |
| parser.add_argument('--model', type=str, default="google/gemma-3-4b-it") |
| parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp2']) |
| parser.add_argument('--devices', default=1, type=int) |
| parser.add_argument('--num-nodes', default=1, type=int) |
| parser.add_argument('--mbs', default=1, type=int) |
| parser.add_argument('--gbs', default=4, type=int) |
| parser.add_argument( |
| "--log_dir", type=str, required=False, default="/results", help="Directory for logging and checkpoints" |
| ) |
| parser.add_argument('--accelerator', default='gpu', choices=['gpu']) |
| parser.add_argument('--max-steps', type=int, default=1000) |
| parser.add_argument('--wandb-project', type=str, default=None) |
| parser.add_argument('--disable-ckpt', action='store_false') |
| parser.add_argument('--use-4bit', help="Load model in 4bit", action="store_true") |
| parser.add_argument( |
| "--data_path", |
| type=str, |
| default="quintend/rdr-items", |
| help="Path to the dataset. Can be a local path or a HF dataset name", |
| ) |
| parser.add_argument("--peft", type=str, default="none", choices=["lora", "none"], help="Which peft to use") |
| parser.add_argument("--freeze-vision-model", action="store_true", help="Freeze the vision model parameters") |
| parser.add_argument("--freeze-language-model", action="store_true", help="Freeze the language model parameters") |
| args = parser.parse_args() |
|
|
| dataset_fn = None |
| |
| |
| if "rdr-items" in args.data_path: |
| dataset_fn = mk_hf_vlm_dataset_rdr |
| elif "cord-v2" in args.data_path: |
| dataset_fn = mk_hf_vlm_dataset_cord_v2 |
| elif "fineweb-edu" in args.data_path: |
| dataset_fn = mk_hf_vlm_dataset_fineweb_edu |
| else: |
| raise NotImplementedError |
|
|
| processor = vlm.HFAutoModelForImageTextToText.configure_processor(args.model) |
|
|
| with FirstRankPerNode(): |
| dataset = dataset_fn(args.data_path, processor, args.mbs, args.gbs) |
|
|
| model_kwargs = {} |
|
|
| |
| if "google/gemma-3" in args.model: |
| model_kwargs["attn_implementation"] = "eager" |
|
|
| |
| |
| if args.freeze_language_model: |
| raise ValueError("Freezing language model is not supported for current version of VLM automodel") |
|
|
| model = vlm.HFAutoModelForImageTextToText( |
| args.model, |
| load_in_4bit=args.use_4bit, |
| processor=processor, |
| freeze_language_model=args.freeze_language_model, |
| freeze_vision_model=args.freeze_vision_model, |
| **model_kwargs, |
| ) |
|
|
| peft = None |
| if args.peft == 'lora': |
| peft = llm.peft.LoRA( |
| target_modules=['*_proj'], |
| dim=16, |
| lora_dtype=torch.bfloat16 if args.use_4bit else None, |
| ) |
| nemo_logger = nl.NeMoLogger( |
| log_dir=args.log_dir, |
| name=args.name, |
| wandb=WandbLogger(project=args.wandb_project, name=args.name) if args.wandb_project is not None else None, |
| ) |
|
|
| llm.finetune( |
| model=model, |
| data=dataset, |
| trainer=nl.Trainer( |
| devices=args.devices, |
| max_steps=args.max_steps, |
| accelerator=args.accelerator, |
| strategy=make_strategy(args.strategy, model, args.devices, args.num_nodes, adapter_only=False), |
| log_every_n_steps=1, |
| limit_val_batches=0.0, |
| num_sanity_val_steps=0, |
| accumulate_grad_batches=max(1, args.gbs // args.mbs), |
| gradient_clip_val=1, |
| use_distributed_sampler=False, |
| enable_checkpointing=args.disable_ckpt, |
| precision='bf16-mixed', |
| num_nodes=args.num_nodes, |
| ), |
| |
| |
| optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)), |
| log=nemo_logger, |
| peft=peft, |
| ) |
|
|