| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import argparse |
| import sys |
|
|
| from nemo.deploy.multimodal import NemoQueryMultimodal |
|
|
|
|
def get_args(argv) -> argparse.Namespace:
    """Parse the command-line options for a Triton multimodal query.

    Args:
        argv: Argument strings to parse, without the program name
            (the script entry point passes ``sys.argv[1:]``).

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args``. Its
        attributes are named after the long options below (``url``,
        ``model_name``, ``model_type``, ``input_text``, ``input_media``,
        ``batch_size``, ``max_output_len``, ``top_k``, ``top_p``,
        ``temperature``, ``repetition_penalty``, ``num_beams``,
        ``init_timeout`` and ``lora_task_uids``).

    Raises:
        SystemExit: Raised by ``argparse`` when a required option is missing
            or a value cannot be converted to the declared type.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        # Plain literal: the text has no placeholders, so no f-prefix is needed.
        description="Query Triton Multimodal server",
    )

    # Server and model selection.
    parser.add_argument("-u", "--url", default="0.0.0.0", type=str, help="url for the triton server")
    parser.add_argument("-mn", "--model_name", required=True, type=str, help="Name of the triton model")
    parser.add_argument("-mt", "--model_type", required=True, type=str, help="Type of the triton model")

    # Prompt inputs.
    parser.add_argument("-int", "--input_text", required=True, type=str, help="Input text")
    parser.add_argument("-im", "--input_media", required=True, type=str, help="File path of input media")

    # Generation / sampling controls.
    parser.add_argument("-bs", "--batch_size", default=1, type=int, help="Batch size")
    parser.add_argument("-mol", "--max_output_len", default=128, type=int, help="Max output token length")
    parser.add_argument("-tk", "--top_k", default=1, type=int, help="top_k")
    parser.add_argument("-tpp", "--top_p", default=0.0, type=float, help="top_p")
    parser.add_argument("-t", "--temperature", default=1.0, type=float, help="temperature")
    parser.add_argument("-rp", "--repetition_penalty", default=1.0, type=float, help="repetition_penalty")
    parser.add_argument("-nb", "--num_beams", default=1, type=int, help="num_beams")

    # Connection behaviour.
    parser.add_argument("-it", "--init_timeout", default=60.0, type=float, help="init timeout for the triton server")

    # Values are kept as strings (type=str); with nargs="+" the attribute is
    # a list of strings when the option is given, otherwise None.
    parser.add_argument(
        "-lt",
        "--lora_task_uids",
        default=None,
        type=str,
        nargs="+",
        help="The list of LoRA task uids; use -1 to disable the LoRA module",
    )

    return parser.parse_args(argv)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the CLI options, send a single query to the
    # Triton multimodal model and print whatever the client returns.
    args = get_args(sys.argv[1:])

    # Client bound to the requested server URL, model name and model type.
    nq = NemoQueryMultimodal(
        url=args.url,
        model_name=args.model_name,
        model_type=args.model_type,
    )

    # Collect the query keyword arguments first so the call itself stays short.
    query_kwargs = {
        "input_text": args.input_text,
        "input_media": args.input_media,
        "batch_size": args.batch_size,
        "max_output_len": args.max_output_len,
        "top_k": args.top_k,
        "top_p": args.top_p,
        "temperature": args.temperature,
        "repetition_penalty": args.repetition_penalty,
        "num_beams": args.num_beams,
        "init_timeout": args.init_timeout,
        # The CLI option is --lora_task_uids; query() takes it as lora_uids.
        "lora_uids": args.lora_task_uids,
    }
    output = nq.query(**query_kwargs)
    print(output)
|
|