| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
"""Gemma3 language model generation script."""
|
|
| import torch |
| from transformers import AutoTokenizer |
|
|
| from nemo import lightning as nl |
| from nemo.collections.llm.gpt.model.gemma3 import Gemma3Model |
|
|
# Hugging Face model id: used to load the tokenizer and, with an ``hf://``
# prefix, as the source passed to ``fabric.import_model``.
HF_MODEL_NAME = "google/gemma-3-1b-it"
|
|
|
|
def _greedy_decode(model, prompt_ids, max_new_tokens):
    """Append ``max_new_tokens`` greedily chosen tokens to ``prompt_ids``.

    The full, growing sequence is passed to the model at every step and the
    argmax over the last position's output is appended.

    Args:
        model: Callable accepting ``input_ids``, ``position_ids`` and
            ``attention_mask`` keyword arguments.
        prompt_ids: Token id tensor for the prompt; indexed as
            ``[batch, sequence]`` by this function.
        max_new_tokens: Number of decoding steps to run.

    Returns:
        The prompt ids with the generated ids concatenated along the last
        dimension.
    """
    token_ids = prompt_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Build the positions directly on the tokens' device rather than
            # on the CPU followed by a host-to-device copy each step.
            position_ids = torch.arange(
                token_ids.shape[-1],
                dtype=torch.int64,
                device=token_ids.device,
            )
            logits = model(
                input_ids=token_ids,
                position_ids=position_ids,
                attention_mask=None,
            )
            # NOTE(review): assumes the model output is indexed as
            # (batch, sequence, vocab) -- confirm against Gemma3Model.
            next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)
            token_ids = torch.cat([token_ids, next_token], dim=-1)
    return token_ids


def main(max_new_tokens=10):
    """Entrypoint: import the HF Gemma3 checkpoint and greedily generate a reply.

    Loads the tokenizer for ``HF_MODEL_NAME``, imports the checkpoint as a
    ``Gemma3Model`` through a single-GPU Megatron strategy, generates tokens
    for a fixed chat prompt, and prints the decoded result.

    Args:
        max_new_tokens: Number of tokens to generate after the prompt.
            Defaults to 10, the value that used to be hard-coded.
    """
    tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        pipeline_dtype=torch.bfloat16,
        virtual_pipeline_model_parallel_size=None,
        context_parallel_size=1,
        expert_model_parallel_size=1,
        sequence_parallel=False,
        setup_optimizers=False,
        store_optimizer_states=False,
    )

    trainer = nl.Trainer(
        accelerator="gpu",
        devices=1,
        num_nodes=1,
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
        enable_checkpointing=False,
    )
    fabric = trainer.to_fabric()
    model = fabric.import_model(f"hf://{HF_MODEL_NAME}", Gemma3Model)
    # Unwrap the fabric wrapper and run the underlying module on the GPU.
    model = model.module.cuda()
    model.eval()

    # A batch containing one single-turn conversation.
    messages = [
        [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Who are you?"},
                ],
            },
        ],
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    # No defensive clone is needed: decoding only concatenates into new
    # tensors and never writes into the prompt tensor.
    prompt_ids = inputs["input_ids"].to("cuda")
    generated_ids = _greedy_decode(model, prompt_ids, max_new_tokens)

    outputs = tokenizer.batch_decode(generated_ids)

    print(outputs)
|
|
|
|
# Run the generation demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|