# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example:
  python scripts/vlm/gemma3vl_generate.py --local_model_path="path/to/converted_nemo_checkpoint"
"""

import argparse
from pathlib import Path

import torch
import torch.distributed as dist
from megatron.core import parallel_state
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func

import nemo.lightning as nl
from nemo.collections.common.tokenizers import AutoTokenizer
from nemo.collections.vlm import Gemma3VLModel
from nemo.collections.vlm.inference.base import _setup_trainer_and_restore_model
from nemo.lightning import io
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.utils.get_rank import get_last_rank
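

# One-shot iterator: Megatron's forward-backward schedule pulls batches from a
# data iterator, so each decode step wraps its single input batch in an object
# that yields it once and then raises StopIteration.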
class SingleBatchIterator:
    def __init__(self, pixel_values, input_ids, position_ids):
        self.batch = dict(
            pixel_values=pixel_values,
            input_ids=input_ids,
            position_ids=position_ids,
        )
        self._yielded = False

    def __iter__(self):
        return self

    def __next__(self):
        if self._yielded:
            raise StopIteration
        self._yielded = True
        return self.batch
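

# Forward step in the form Megatron's schedule expects: run one batch through
# the model and return the logits together with an identity loss_func, so the
# schedule hands the logits back unchanged under collect_non_loss_data=True.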
def gemma3_forward_step(data_iterator, model, **kwargs):
    batch = next(data_iterator)
    forward_args = {
        "input_ids": batch["input_ids"],
        "position_ids": batch["position_ids"],
        "pixel_values": batch.get("pixel_values", None),
        "loss_mask": batch.get("loss_mask", None),
        "labels": batch.get("labels", None),
    }

    def loss_func(x, **kwargs):
        return x

    return model(**forward_args), loss_func


def main(args) -> None:
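    # Megatron parallelism setup: sequence parallelism is enabled only when
    # TP > 1, and the whole pipeline runs in bf16 for inference.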
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=args.tp,
        pipeline_model_parallel_size=args.pp,
        sequence_parallel=args.tp > 1,
        ckpt_include_optimizer=False,
        ckpt_load_strictness="log_all",
        pipeline_dtype=torch.bfloat16,
    )
    trainer = nl.Trainer(
        devices=min(args.tp * args.pp, 8),
        num_nodes=max(args.tp * args.pp // 8, 1),
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
        enable_checkpointing=False,
    )

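    # Either restore a converted NeMo checkpoint from disk, or import the
    # Hugging Face checkpoint on the fly.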
    if args.local_model_path:
        path = Path(args.local_model_path)
        model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(path), subpath="model")
        _setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)
    else:
        fabric = trainer.to_fabric()
        model = fabric.import_model("hf://google/gemma-3-4b-it", Gemma3VLModel)

    model = model.module.cuda()
    model.eval()

    from transformers import AutoProcessor

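    # The HF processor renders the chat template into tokens and image tiles;
    # NeMo's AutoTokenizer wraps the matching HF tokenizer, which is used later
    # to decode the generated ids.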
    model_id = 'google/gemma-3-4b-it'
    processor = AutoProcessor.from_pretrained(model_id)
    gemma_tokenizer = AutoTokenizer(model_id)
    hf_tokenizer = gemma_tokenizer.tokenizer

    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG",
                },
                {"type": "text", "text": "What animal is on the candy?"},
            ],
        },
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

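    # Move inputs to the GPU: pixel_values is cast to bf16 and given an extra
    # leading dimension to match the layout the model expects, and position
    # ids are a plain 0..seq_len-1 range.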
    input_ids = inputs["input_ids"].cuda()
    pixel_values = inputs["pixel_values"].cuda().unsqueeze(0).to(dtype=torch.bfloat16)
    position_ids = (
        torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device).unsqueeze(0).expand_as(input_ids)
    )
    generated_ids = input_ids.clone()

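    # Greedy decoding, one token per step, capped at 20 new tokens. Gemma 3
    # ends a turn with <eos> (id 1) or <end_of_turn> (id 106).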
    stop_tokens = [1, 106]

    for step in range(20):
        with torch.no_grad():
            if torch.distributed.get_rank() == 0:
                print(step)
            fwd_bwd_function = get_forward_backward_func()

            iterator = SingleBatchIterator(pixel_values, input_ids, position_ids)

            output = fwd_bwd_function(
                forward_step_func=gemma3_forward_step,
                data_iterator=iterator,
                model=model,
                num_microbatches=1,
                forward_only=True,
                seq_length=input_ids.size(1),
                micro_batch_size=1,
                collect_non_loss_data=True,
            )
            if isinstance(output, list) and len(output) > 0:
                output = output[0]

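            # Under tensor parallelism the logits are sharded along the vocab
            # dimension, so the last pipeline stage gathers every TP shard
            # before taking an argmax over the full vocabulary.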
            if parallel_state.is_pipeline_last_stage():
                world_size = parallel_state.get_tensor_model_parallel_world_size()
                gathered_tensors = [torch.zeros_like(output) for _ in range(world_size)]
                dist.all_gather(gathered_tensors, output, group=parallel_state.get_tensor_model_parallel_group())
                output = torch.cat(gathered_tensors, dim=2)

                next_token_ids = torch.argmax(output[:, -1], dim=-1, keepdim=True)
            else:
                next_token_ids = torch.ones((1, 1), device=generated_ids.device, dtype=generated_ids.dtype)

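            # Only the last pipeline stage holds the real token; broadcast it
            # so every rank appends the same id to its copy of the sequence.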
            torch.distributed.broadcast(next_token_ids, get_last_rank())
            generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)

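            # No KV cache: the whole extended sequence is re-fed on the next
            # step, with position ids rebuilt to match the new length.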
            input_ids = generated_ids
            position_ids = (
                torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
                .unsqueeze(0)
                .expand_as(input_ids)
            )

            if next_token_ids.item() in stop_tokens:
                break

    generated_texts = hf_tokenizer.decode(list(generated_ids[0]))
    if torch.distributed.get_rank() == 0:
        print("======== GENERATED TEXT OUTPUT ========")
        print(f"{generated_texts}")
        print("=======================================")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Gemma3 Multimodal Inference")
    parser.add_argument(
        "--local_model_path",
        type=str,
        default=None,
        help="Local path to the model if not loading from Hugging Face.",
    )
    parser.add_argument('--tp', type=int, default=1, help="Tensor model parallel size")
    parser.add_argument('--pp', type=int, default=1, help="Pipeline model parallel size")
    args = parser.parse_args()

    main(args)