# Copyright (C) 2026 Embedl AB
"""Run inference on the Embedl All Minilm L6 V2 INT8 sentence encoder via TensorRT.
Builds a TensorRT engine from the shipped
``embedl_all-MiniLM-L6-v2_int8.onnx`` artifact (Q/DQ nodes baked in by
embedl-deploy) and encodes a sentence into an L2-normalised
embedding. The first run caches the engine to
``embedl_all-MiniLM-L6-v2_int8.engine`` so reuse is fast.
Requires TensorRT >= 10.1, pycuda (or cuda-python), and transformers
(for the tokenizer). Tested on NVIDIA Jetson AGX Orin (JetPack 6)
and discrete GPUs with CUDA 12.
Usage::
python infer_trt.py --sentence "A man is eating food."
"""
import argparse
import time
from pathlib import Path

import numpy as np
import tensorrt as trt
from transformers import AutoTokenizer

try:
import pycuda.autoinit # noqa: F401 (initializes CUDA context)
import pycuda.driver as cuda
except ImportError as exc: # pragma: no cover
raise SystemExit(
"pycuda is required. Install with: pip install pycuda"
    ) from exc

ONNX_PATH = Path(__file__).with_name("embedl_all-MiniLM-L6-v2_int8.onnx")
ENGINE_PATH = Path(__file__).with_name("embedl_all-MiniLM-L6-v2_int8.engine")
TOKENIZER_ID = "sentence-transformers/all-MiniLM-L6-v2"
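# Sentences are padded/truncated to this fixed length, which is assumed to
# match the static input shape baked into the shipped ONNX.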
MAX_LENGTH = 128
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)


def build_engine() -> bytes:
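    """Parse the Q/DQ ONNX model and return a serialized TensorRT engine."""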
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, TRT_LOGGER)
with open(ONNX_PATH, "rb") as f:
if not parser.parse(f.read()):
for i in range(parser.num_errors):
print(parser.get_error(i))
raise RuntimeError("ONNX parse failed.")
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)
config.set_flag(trt.BuilderFlag.INT8)
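    # The network is explicitly quantized: the Q/DQ nodes in the ONNX carry
    # the INT8 scales, so no calibration cache is required here.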
config.builder_optimization_level = 5
serialized = builder.build_serialized_network(network, config)
if serialized is None:
raise RuntimeError("Engine build failed.")
    return bytes(serialized)


def load_or_build_engine() -> trt.ICudaEngine:
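    """Deserialize the cached engine, building and caching it on first use.

    Serialized engines are specific to the GPU and TensorRT version, so
    delete the cached .engine file after changing either.
    """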
if ENGINE_PATH.exists():
data = ENGINE_PATH.read_bytes()
else:
print(f"Building engine (first run) → {ENGINE_PATH.name} …")
data = build_engine()
ENGINE_PATH.write_bytes(data)
runtime = trt.Runtime(TRT_LOGGER)
    return runtime.deserialize_cuda_engine(data)


def tokenize(tokenizer, sentence: str):
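    """Tokenize one sentence to contiguous int64 arrays of shape (1, MAX_LENGTH)."""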
enc = tokenizer(
sentence,
padding="max_length",
truncation=True,
max_length=MAX_LENGTH,
return_tensors="np",
)
return (
np.ascontiguousarray(enc["input_ids"].astype(np.int64)),
np.ascontiguousarray(enc["attention_mask"].astype(np.int64)),
    )


def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--sentence", required=True, type=str)
args = parser.parse_args()
if not ONNX_PATH.exists():
raise SystemExit(
f"Expected {ONNX_PATH.name} next to this script. "
"Did you download the HF repo?"
)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
input_ids, attention_mask = tokenize(tokenizer, args.sentence)
engine = load_or_build_engine()
context = engine.create_execution_context()
    # Resolve I/O tensor names by mode (input vs output): inputs and outputs
    # may be interleaved in the engine's tensor list, so classify them by
    # TensorIOMode rather than by positional index.
input_names = []
output_names = []
for i in range(engine.num_io_tensors):
name = engine.get_tensor_name(i)
mode = engine.get_tensor_mode(name)
if mode == trt.TensorIOMode.INPUT:
input_names.append(name)
else:
output_names.append(name)
if len(input_names) != 2 or len(output_names) != 1:
raise RuntimeError(
f"Expected 2 inputs / 1 output, got "
f"{len(input_names)}/{len(output_names)}."
)
# Feed the inputs by canonical name so input_ids / attention_mask
# bind to the right tensor regardless of engine ordering.
inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
out_shape = tuple(engine.get_tensor_shape(output_names[0]))
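    # Assumes a static output shape and float32 dtype as declared in the
    # ONNX; a dynamic-shape engine would instead query
    # context.get_tensor_shape() after the input shapes are set.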
h_out = np.empty(out_shape, dtype=np.float32)
d_inputs = {}
for name in input_names:
arr = inputs[name]
d_inputs[name] = cuda.mem_alloc(arr.nbytes)
d_out = cuda.mem_alloc(h_out.nbytes)
stream = cuda.Stream()
for name in input_names:
cuda.memcpy_htod_async(d_inputs[name], inputs[name], stream)
context.set_tensor_address(name, int(d_inputs[name]))
context.set_tensor_address(output_names[0], int(d_out))
    # Warm-up runs absorb one-time lazy initialization before timing.
for _ in range(5):
context.execute_async_v3(stream.handle)
stream.synchronize()
t0 = time.perf_counter()
context.execute_async_v3(stream.handle)
stream.synchronize()
latency_ms = (time.perf_counter() - t0) * 1000.0
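    # The timed region covers one enqueue plus device sync; tokenization and
    # host<->device copies are excluded.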
cuda.memcpy_dtoh_async(h_out, d_out, stream)
stream.synchronize()
embedding = h_out.reshape(-1)
first8 = ", ".join(f"{v:+.4f}" for v in embedding[:8])
print(f"Latency (single-run, GPU compute): {latency_ms:.2f} ms")
print(f"Sentence: {args.sentence!r}")
print(f"Embedding shape: {embedding.shape}")
print(f"First 8 dims: [{first8}]")
if __name__ == "__main__":
main()