| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
import argparse
import os

import tensorrt as trt

from nemo.collections.llm.gpt.model.hf_llama_embedding import get_llama_bidirectional_hf_model
from nemo.export.onnx_llm_exporter import OnnxLLMExporter
from nemo.utils import logging
|
|
|
|
def get_args():
    """Parse and return the command-line arguments for the export test script.

    Returns:
        argparse.Namespace: Parsed arguments controlling the HF model source,
        pooling/normalization behavior, ONNX/TensorRT output paths, ONNX opset,
        and TensorRT version compatibility.
    """
    parser = argparse.ArgumentParser(description="Test ONNX and TensorRT export for LLM embedding models.")

    # Model source and embedding behavior.
    parser.add_argument("--hf_model_path", type=str, required=True, help="Hugging Face model id or path.")
    parser.add_argument("--pooling_strategy", type=str, default="avg", help="Pooling strategy for the model.")
    parser.add_argument("--normalize", default=False, action="store_true", help="Normalize the embeddings or not.")

    # Export destinations and format options.
    parser.add_argument("--onnx_export_path", type=str, default="/tmp/onnx_model/", help="Path to store ONNX model.")
    parser.add_argument("--onnx_opset", type=int, default=17, help="ONNX version to use for export.")
    parser.add_argument("--trt_model_path", type=str, default="/tmp/trt_model/", help="Path to store TensorRT model.")
    parser.add_argument(
        "--trt_version_compatible",
        default=False,
        action="store_true",
        help="Whether to generate version compatible TensorRT models.",
    )

    return parser.parse_args()
|
|
|
|
def export_onnx_trt(args):
    """Export an HF Llama embedding model to ONNX, build a TensorRT engine, and smoke-test it.

    Steps: load the bidirectional HF model, export it to ONNX with dynamic
    batch/sequence axes, convert the ONNX model to TensorRT with selected
    layers pinned to FP32, then run a tiny forward pass to sanity-check the
    exported model.

    Args:
        args (argparse.Namespace): Parsed CLI arguments (see ``get_args``).
    """
    model, tokenizer = get_llama_bidirectional_hf_model(
        model_name_or_path=args.hf_model_path,
        normalize=args.normalize,
        pooling_mode=args.pooling_strategy,
        trust_remote_code=True,
    )

    # Dynamic axes: batch size and sequence length vary at inference time.
    input_names = ["input_ids", "attention_mask", "dimensions"]
    dynamic_axes_input = {
        "input_ids": {0: "batch_size", 1: "seq_length"},
        "attention_mask": {0: "batch_size", 1: "seq_length"},
        "dimensions": {0: "batch_size"},
    }

    output_names = ["embeddings"]
    dynamic_axes_output = {"embeddings": {0: "batch_size", 1: "embedding_dim"}}

    onnx_exporter = OnnxLLMExporter(
        onnx_model_dir=args.onnx_export_path,
        model=model,
        tokenizer=tokenizer,
    )

    onnx_exporter.export(
        input_names=input_names,
        output_names=output_names,
        opset=args.onnx_opset,
        dynamic_axes_input=dynamic_axes_input,
        dynamic_axes_output=dynamic_axes_output,
        export_dtype="fp32",
    )

    # TensorRT optimization profile: [min, opt, max] shapes per input.
    input_profiles = [
        {
            "input_ids": [[1, 3], [16, 128], [64, 256]],
            "attention_mask": [[1, 3], [16, 128], [64, 256]],
            "dimensions": [[1], [16], [64]],
        }
    ]

    trt_builder_flags = None
    if args.trt_version_compatible:
        trt_builder_flags = [trt.BuilderFlag.VERSION_COMPATIBLE]

    # Keep numerically sensitive layers (final norm, pooling, L2-normalization)
    # in FP32 to avoid precision loss in the built engine.
    override_layers_to_fp32 = [
        "/model/norm/",
        "/pooling_module",
        "/ReduceL2",
        "/Div",
    ]
    override_layernorm_precision_to_fp32 = True
    profiling_verbosity = "layer_names_only"

    onnx_exporter.export_onnx_to_trt(
        trt_model_dir=args.trt_model_path,
        profiles=input_profiles,
        override_layernorm_precision_to_fp32=override_layernorm_precision_to_fp32,
        override_layers_to_fp32=override_layers_to_fp32,
        profiling_verbosity=profiling_verbosity,
        trt_builder_flags=trt_builder_flags,
    )

    # Sanity checks: both export steps must have produced output directories.
    assert os.path.exists(args.trt_model_path)
    assert os.path.exists(args.onnx_export_path)

    # Smoke-test the exported ONNX model with a tiny batch.
    prompt = ["hello", "world"]
    prompt = onnx_exporter.get_tokenizer(prompt)
    # NOTE(review): presumably the requested output embedding dimension per
    # batch item — confirm against OnnxLLMExporter.forward's expected schema.
    prompt["dimensions"] = [[2]]

    output = onnx_exporter.forward(prompt)
    if output is None:
        # forward() returns None when ONNX runtime is unavailable; warn
        # instead of failing. (Plain string: no placeholders, so no f-string.)
        logging.warning("Output is None because ONNX runtime is not installed.")
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: parse arguments, then run the full export pipeline.
    cli_args = get_args()
    export_onnx_trt(cli_args)
|
|