# Copyright (C) 2026 Embedl AB
"""Run inference on the Embedl all-MiniLM-L6-v2 INT8 sentence encoder via TensorRT.

Builds a TensorRT engine from the shipped ``embedl_all-MiniLM-L6-v2_int8.onnx``
artifact (Q/DQ nodes baked in by embedl-deploy) and encodes a sentence into an
L2-normalised embedding. The first run caches the engine to
``embedl_all-MiniLM-L6-v2_int8.engine`` so subsequent runs start fast.

Requires TensorRT >= 10.1, pycuda (or cuda-python), and transformers (for the
tokenizer). Tested on NVIDIA Jetson AGX Orin (JetPack 6) and discrete GPUs with
CUDA 12.

Usage::

    python infer_trt.py --sentence "A man is eating food."
"""
import argparse
import time
from pathlib import Path

import numpy as np
import tensorrt as trt
from transformers import AutoTokenizer

try:
    import pycuda.autoinit  # noqa: F401 (initializes the CUDA context)
    import pycuda.driver as cuda
except ImportError as exc:  # pragma: no cover
    raise SystemExit(
        "pycuda is required. Install with: pip install pycuda"
    ) from exc

ONNX_PATH = Path(__file__).with_name("embedl_all-MiniLM-L6-v2_int8.onnx")
ENGINE_PATH = Path(__file__).with_name("embedl_all-MiniLM-L6-v2_int8.engine")
TOKENIZER_ID = "sentence-transformers/all-MiniLM-L6-v2"
MAX_LENGTH = 128

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)


def build_engine() -> bytes:
    builder = trt.Builder(TRT_LOGGER)
    # TensorRT >= 10 networks are always explicit-batch; no creation flags
    # are needed (the old EXPLICIT_BATCH flag is deprecated).
    network = builder.create_network(0)
    parser = trt.OnnxParser(network, TRT_LOGGER)
    with open(ONNX_PATH, "rb") as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError("ONNX parse failed.")
    config = builder.create_builder_config()
    # INT8 precision comes from the Q/DQ scales baked into the ONNX; FP16
    # covers the layers the Q/DQ nodes leave unquantized.
    config.set_flag(trt.BuilderFlag.FP16)
    config.set_flag(trt.BuilderFlag.INT8)
    config.builder_optimization_level = 5
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("Engine build failed.")
    return bytes(serialized)


def load_or_build_engine() -> trt.ICudaEngine:
    if ENGINE_PATH.exists():
        data = ENGINE_PATH.read_bytes()
    else:
        print(f"Building engine (first run) → {ENGINE_PATH.name} …")
        data = build_engine()
        ENGINE_PATH.write_bytes(data)
    runtime = trt.Runtime(TRT_LOGGER)
    engine = runtime.deserialize_cuda_engine(data)
    if engine is None:
        raise RuntimeError(f"Failed to deserialize {ENGINE_PATH.name}.")
    return engine


def tokenize(tokenizer, sentence: str):
    enc = tokenizer(
        sentence,
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="np",
    )
    return (
        np.ascontiguousarray(enc["input_ids"].astype(np.int64)),
        np.ascontiguousarray(enc["attention_mask"].astype(np.int64)),
    )


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--sentence", required=True, type=str)
    args = parser.parse_args()

    if not ONNX_PATH.exists():
        raise SystemExit(
            f"Expected {ONNX_PATH.name} next to this script. "
            "Did you download the HF repo?"
        )

    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
    input_ids, attention_mask = tokenize(tokenizer, args.sentence)

    engine = load_or_build_engine()
    context = engine.create_execution_context()

    # Resolve I/O tensor names by mode (input vs output); the order in the
    # engine isn't guaranteed to match get_tensor_name(0..N).
    input_names = []
    output_names = []
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            input_names.append(name)
        else:
            output_names.append(name)
    if len(input_names) != 2 or len(output_names) != 1:
        raise RuntimeError(
            f"Expected 2 inputs / 1 output, got "
            f"{len(input_names)}/{len(output_names)}."
        )
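    # Hedged sanity check (not part of the artifact's documented contract):
    # the engine is expected to take (1, MAX_LENGTH) token tensors. Catch a
    # mismatched or rebuilt artifact early, and pin any dynamic axis on the
    # context before enqueueing; this is a no-op for the static-shape case.
    for name in input_names:
        shape = tuple(engine.get_tensor_shape(name))
        if -1 in shape:
            context.set_input_shape(name, (1, MAX_LENGTH))
        elif shape != (1, MAX_LENGTH):
            raise RuntimeError(f"Unexpected shape for {name!r}: {shape}")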
    # Feed the inputs by canonical name so input_ids / attention_mask
    # bind to the right tensor regardless of engine ordering.
    inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
    missing = [n for n in input_names if n not in inputs]
    if missing:
        raise RuntimeError(f"Unexpected engine input names: {missing}")

    # Query the output shape from the context so any shapes set above are
    # reflected; for a static engine this matches engine.get_tensor_shape().
    out_shape = tuple(context.get_tensor_shape(output_names[0]))
    h_out = np.empty(out_shape, dtype=np.float32)

    d_inputs = {}
    for name in input_names:
        arr = inputs[name]
        d_inputs[name] = cuda.mem_alloc(arr.nbytes)
    d_out = cuda.mem_alloc(h_out.nbytes)

    stream = cuda.Stream()
    for name in input_names:
        cuda.memcpy_htod_async(d_inputs[name], inputs[name], stream)
        context.set_tensor_address(name, int(d_inputs[name]))
    context.set_tensor_address(output_names[0], int(d_out))

    # Warm-up, then a single timed run.
    for _ in range(5):
        context.execute_async_v3(stream.handle)
    stream.synchronize()

    t0 = time.perf_counter()
    context.execute_async_v3(stream.handle)
    stream.synchronize()
    latency_ms = (time.perf_counter() - t0) * 1000.0

    cuda.memcpy_dtoh_async(h_out, d_out, stream)
    stream.synchronize()

    embedding = h_out.reshape(-1)
    first8 = ", ".join(f"{v:+.4f}" for v in embedding[:8])
    print(f"Latency (single-run, GPU compute): {latency_ms:.2f} ms")
    print(f"Sentence: {args.sentence!r}")
    print(f"Embedding shape: {embedding.shape}")
    print(f"First 8 dims: [{first8}]")


if __name__ == "__main__":
    main()
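
# Parity check (hedged sketch; assumes `pip install sentence-transformers`,
# which this script does not otherwise require): compare the INT8 TensorRT
# embedding against the FP32 baseline in a REPL. Both vectors are
# L2-normalised, so their dot product is the cosine similarity; it should
# stay close to 1 if quantization preserved accuracy.
#
#   from sentence_transformers import SentenceTransformer
#   ref = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
#   ref_emb = ref.encode("A man is eating food.", normalize_embeddings=True)
#   print(float(ref_emb @ trt_embedding))  # trt_embedding: this script's output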