dann-od's picture
First version of model card
90c4404 verified
# Copyright (C) 2026 Embedl AB
"""Run inference on the Embedl All Minilm L6 V2 INT8 sentence encoder via TensorRT.
Builds a TensorRT engine from the shipped
``embedl_all-MiniLM-L6-v2_int8.onnx`` artifact (Q/DQ nodes baked in by
embedl-deploy) and encodes a sentence into an L2-normalised
embedding. The first run caches the engine to
``embedl_all-MiniLM-L6-v2_int8.engine`` so reuse is fast.
Requires TensorRT >= 10.1, pycuda (imported directly by this script;
cuda-python alone is not sufficient), and transformers (for the
tokenizer). Tested on NVIDIA Jetson AGX Orin (JetPack 6)
and discrete GPUs with CUDA 12.
Usage::
python infer_trt.py --sentence "A man is eating food."
"""
import argparse
import time
from pathlib import Path
import numpy as np
import tensorrt as trt
from transformers import AutoTokenizer
try:
import pycuda.autoinit # noqa: F401 (initializes CUDA context)
import pycuda.driver as cuda
except ImportError as exc: # pragma: no cover
raise SystemExit(
"pycuda is required. Install with: pip install pycuda"
) from exc
# Quantized ONNX model, looked up in the same directory as this script.
ONNX_PATH = Path(__file__).with_name("embedl_all-MiniLM-L6-v2_int8.onnx")
# Serialized TensorRT engine, cached next to this script after the first build.
ENGINE_PATH = Path(__file__).with_name("embedl_all-MiniLM-L6-v2_int8.engine")
# Identifier handed to AutoTokenizer.from_pretrained() in main().
TOKENIZER_ID = "sentence-transformers/all-MiniLM-L6-v2"
# Token count every sentence is padded or truncated to in tokenize().
MAX_LENGTH = 128
# Logger shared by the TensorRT builder, ONNX parser and runtime,
# created at WARNING severity.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
def build_engine() -> bytes:
    """Compile the shipped ONNX model into a serialized TensorRT engine.

    The model at ``ONNX_PATH`` is parsed into a new network definition and
    built with the FP16 and INT8 builder flags enabled at builder
    optimization level 5.

    Returns:
        The serialized engine as ``bytes``.

    Raises:
        RuntimeError: If the ONNX parser rejects the model (every parser
            error is printed first) or the builder produces no engine.
    """
    trt_builder = trt.Builder(TRT_LOGGER)
    creation_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    net = trt_builder.create_network(creation_flags)
    onnx_parser = trt.OnnxParser(net, TRT_LOGGER)

    model_bytes = ONNX_PATH.read_bytes()
    parsed_ok = onnx_parser.parse(model_bytes)
    if not parsed_ok:
        # Print each parser diagnostic before aborting the build.
        for problem in map(onnx_parser.get_error, range(onnx_parser.num_errors)):
            print(problem)
        raise RuntimeError("ONNX parse failed.")

    build_config = trt_builder.create_builder_config()
    # Enable both reduced-precision modes for the build.
    for precision_flag in (trt.BuilderFlag.FP16, trt.BuilderFlag.INT8):
        build_config.set_flag(precision_flag)
    build_config.builder_optimization_level = 5

    plan = trt_builder.build_serialized_network(net, build_config)
    if plan is None:
        raise RuntimeError("Engine build failed.")
    return bytes(plan)
def load_or_build_engine() -> trt.ICudaEngine:
    """Return a deserialized TensorRT engine, building and caching it if needed.

    If ``ENGINE_PATH`` exists, its bytes are deserialized and returned.
    When that file cannot be deserialized (for example because it was
    produced by a different TensorRT version or for a different GPU), the
    engine is rebuilt from the ONNX model and the cache file is
    overwritten instead of handing ``None`` back to the caller.

    Returns:
        A ready-to-use ``trt.ICudaEngine``.

    Raises:
        RuntimeError: If the build fails (raised by ``build_engine``) or a
            freshly built engine cannot be deserialized.
    """
    runtime = trt.Runtime(TRT_LOGGER)

    if ENGINE_PATH.exists():
        engine = runtime.deserialize_cuda_engine(ENGINE_PATH.read_bytes())
        if engine is not None:
            return engine
        # deserialize_cuda_engine() signals failure by returning None;
        # treat the cache as stale and fall through to a rebuild.
        print(
            f"Cached engine {ENGINE_PATH.name} could not be loaded; "
            "rebuilding …"
        )
    else:
        print(f"Building engine (first run) → {ENGINE_PATH.name} …")

    data = build_engine()
    ENGINE_PATH.write_bytes(data)

    engine = runtime.deserialize_cuda_engine(data)
    if engine is None:
        raise RuntimeError("Failed to deserialize the freshly built engine.")
    return engine
def tokenize(tokenizer, sentence: str):
    """Tokenize *sentence* into fixed-length int64 model inputs.

    The tokenizer is asked for NumPy tensors padded and truncated to
    ``MAX_LENGTH`` tokens.

    Args:
        tokenizer: Callable tokenizer whose result can be indexed with
            ``"input_ids"`` and ``"attention_mask"``.
        sentence: Text to encode.

    Returns:
        A ``(input_ids, attention_mask)`` tuple of C-contiguous
        ``np.int64`` arrays.
    """
    encoded = tokenizer(
        sentence,
        return_tensors="np",
        max_length=MAX_LENGTH,
        truncation=True,
        padding="max_length",
    )
    # Same conversion for both tensors: cast to int64, then guarantee a
    # contiguous buffer so it can be copied to the device as-is.
    ids, mask = (
        np.ascontiguousarray(encoded[key].astype(np.int64))
        for key in ("input_ids", "attention_mask")
    )
    return (ids, mask)
def _split_io_names(engine):
    """Return ``(input_names, output_names)`` for the engine's I/O tensors.

    Tensors are classified by their I/O mode rather than by index, so the
    result does not depend on the order in which the engine lists them.
    Any tensor whose mode is not ``INPUT`` is counted as an output.
    """
    input_names = []
    output_names = []
    for index in range(engine.num_io_tensors):
        name = engine.get_tensor_name(index)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            input_names.append(name)
        else:
            output_names.append(name)
    return input_names, output_names


def main() -> None:
    """Encode one sentence with the TensorRT engine and print the result.

    Parses ``--sentence`` from the command line, tokenizes it, loads (or
    builds) the engine, runs five warm-up inferences plus one timed
    inference, and prints the latency, the sentence, the flattened
    embedding shape and its first eight values.

    Raises:
        SystemExit: If the ONNX model is not next to this script.
        RuntimeError: If the engine or execution context cannot be
            created, or the engine's I/O tensors do not match the
            expected ``input_ids`` / ``attention_mask`` inputs and single
            output.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--sentence", required=True, type=str)
    args = parser.parse_args()

    if not ONNX_PATH.exists():
        raise SystemExit(
            f"Expected {ONNX_PATH.name} next to this script. "
            "Did you download the HF repo?"
        )

    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
    input_ids, attention_mask = tokenize(tokenizer, args.sentence)

    engine = load_or_build_engine()
    if engine is None:
        # Deserialization reports failure by returning None, typically when
        # the cached engine was built for another TensorRT version or GPU.
        raise RuntimeError(
            f"Could not load a TensorRT engine. If {ENGINE_PATH.name} was "
            "built with a different TensorRT version or GPU, delete it "
            "and re-run."
        )
    context = engine.create_execution_context()
    if context is None:
        raise RuntimeError("Failed to create a TensorRT execution context.")

    input_names, output_names = _split_io_names(engine)
    if len(input_names) != 2 or len(output_names) != 1:
        raise RuntimeError(
            f"Expected 2 inputs / 1 output, got "
            f"{len(input_names)}/{len(output_names)}."
        )

    # Feed the inputs by canonical name so input_ids / attention_mask
    # bind to the right tensor regardless of engine ordering.
    inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
    unexpected = [name for name in input_names if name not in inputs]
    if unexpected:
        # Fail with a readable message instead of a bare KeyError below.
        raise RuntimeError(
            f"Engine expects inputs {sorted(input_names)}, but this script "
            f"only provides {sorted(inputs)}."
        )

    out_name = output_names[0]
    out_shape = tuple(engine.get_tensor_shape(out_name))
    # NOTE(review): assumes the engine's output tensor is float32 — confirm.
    h_out = np.empty(out_shape, dtype=np.float32)

    stream = cuda.Stream()
    d_inputs = {}
    for name in input_names:
        host_array = inputs[name]
        d_inputs[name] = cuda.mem_alloc(host_array.nbytes)
        # The host arrays are ordinary (pageable) numpy buffers, so use the
        # synchronous copy; PyCUDA documents the *_async variants as
        # requiring page-locked host memory.
        cuda.memcpy_htod(d_inputs[name], host_array)
        context.set_tensor_address(name, int(d_inputs[name]))
    d_out = cuda.mem_alloc(h_out.nbytes)
    context.set_tensor_address(out_name, int(d_out))

    # Warm-up runs, so the timed run below excludes one-off start-up cost.
    for _ in range(5):
        context.execute_async_v3(stream.handle)
    stream.synchronize()

    # Timed run: enqueue once and wait for the stream to drain.
    t0 = time.perf_counter()
    context.execute_async_v3(stream.handle)
    stream.synchronize()
    latency_ms = (time.perf_counter() - t0) * 1000.0

    # The stream is already idle here, so a synchronous copy is sufficient.
    cuda.memcpy_dtoh(h_out, d_out)

    embedding = h_out.reshape(-1)
    first8 = ", ".join(f"{v:+.4f}" for v in embedding[:8])
    print(f"Latency (single-run, GPU compute): {latency_ms:.2f} ms")
    print(f"Sentence: {args.sentence!r}")
    print(f"Embedding shape: {embedding.shape}")
    print(f"First 8 dims: [{first8}]")


if __name__ == "__main__":
    main()