| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| import argparse |
| import torch |
| import torch.distributed |
| from megatron.core.inference.common_inference_params import CommonInferenceParams |
|
|
| import nemo.lightning as nl |
| from nemo.collections.llm import api |
|
|
|
|
def get_args():
    """Parse command-line arguments for T5 inference with a trained NeMo 2.0 checkpoint.

    Returns:
        argparse.Namespace: Parsed arguments with fields
            devices, checkpoint_path, temperature, top_k, top_p,
            no_space_before_mask, num_tokens_to_generate, prompts,
            encoder_prompts, max_batch_size.
    """
    parser = argparse.ArgumentParser(description='Train a small T5 model using NeMo 2.0')
    parser.add_argument('--devices', type=int, help="Number of devices to use for training.")
    parser.add_argument('--checkpoint-path', type=str, help="Path to trained model.")
    parser.add_argument("--temperature", type=float, default=1.0, help='Sampling temperature.')
    parser.add_argument("--top_k", type=int, default=1, help='Top k sampling.')
    parser.add_argument("--top_p", type=float, default=0.0, help='Top p sampling.')
    parser.add_argument(
        '--no-space-before-mask',
        action='store_true',
        # Some tokenizers (e.g. Tiktoken, SentencePiece) encode "<mask>" differently
        # depending on whether a space precedes it; this flag drops the space.
        help="Flag to omit the space before <mask>. E.g., as in the Tiktokenizer or SentencePiece case.",
    )
    parser.add_argument(
        "--num-tokens-to-generate", type=int, default=30, help='Number of tokens to generate for each prompt.'
    )
    parser.add_argument(
        "--prompts",
        metavar='N',
        type=str,
        nargs='+',
        help='Prompts with each prompt within quotes and separated by space.',
    )
    parser.add_argument(
        "--encoder-prompts",
        metavar='N',
        type=str,
        nargs='+',
        help='Encoder input prompts with each prompt within quotes and separated by space.',
    )
    parser.add_argument("--max-batch-size", type=int, default=1, help='Max number of prompts to process at once.')

    return parser.parse_args()
|
|
|
|
if __name__ == "__main__":

    args = get_args()

    # Single-GPU-group parallelism config: no TP/PP/CP sharding, and no
    # optimizer state since this script only runs inference.
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        context_parallel_size=1,
        sequence_parallel=False,
        setup_optimizers=False,
        store_optimizer_states=False,
    )

    trainer = nl.Trainer(
        accelerator="gpu",
        devices=args.devices,
        num_nodes=1,
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(
            precision="bf16-mixed",
            params_dtype=torch.bfloat16,
            pipeline_dtype=torch.bfloat16,
            autocast_enabled=False,
            grad_reduce_in_fp32=False,
        ),
    )
    # T5 is an encoder-decoder model: the decoder-side prompts are empty and
    # the masked text goes to the encoder.
    prompts = [
        "",
        "",
        "",
    ]
    # Tokenizer-dependent formatting: some tokenizers (Tiktoken/SentencePiece)
    # expect no space before the <mask> sentinel.
    if args.no_space_before_mask:
        encoder_prompts = [
            "Hi<mask>. Hello, how are <mask>?",
            "How<mask> r's are in the<mask> 'strawberry'? Can you<mask> me?",
            "Which number is<mask>? 10.119<mask> 10.19?",
        ]
    else:
        encoder_prompts = [
            "Hi <mask>. Hello, how are <mask>?",
            "How <mask> r's are in the <mask> 'strawberry'? Can you <mask> me?",
            "Which number is <mask>? 10.119 <mask> 10.19?",
        ]

    results = api.generate(
        path=args.checkpoint_path,
        prompts=prompts,
        encoder_prompts=encoder_prompts,
        trainer=trainer,
        add_BOS=True,
        inference_params=CommonInferenceParams(
            temperature=args.temperature,
            top_k=args.top_k,
            # BUGFIX: --top_p was parsed but never forwarded, so the flag was
            # silently ignored.
            top_p=args.top_p,
            num_tokens_to_generate=args.num_tokens_to_generate,
        ),
        text_only=True,
    )
    # Only rank 0 prints, to avoid duplicated output in multi-GPU runs.
    if torch.distributed.get_rank() == 0:
        for i, r in enumerate(results):
            print(prompts[i])
            print("*" * 50)
            print(r)
            print("\n\n")
|