# Copyright (C) 2026 Embedl AB
"""Run inference on the Embedl Mobilevit Small INT8 model via TensorRT.

This script builds a TensorRT engine from the shipped
``embedl_mobilevit_small_int8.onnx`` artifact (Q/DQ nodes baked in by
embedl-deploy) and runs a single image through it. The first run caches
the engine to ``embedl_mobilevit_small_int8.engine`` so reuse is fast.

Requires TensorRT >= 10.1 and pycuda (or cuda-python). Tested on NVIDIA
Jetson AGX Orin (JetPack 6) and discrete GPUs with CUDA 12.

Usage::

    python infer_trt.py --image path/to/image.jpg
"""

import argparse
import time
from pathlib import Path

import numpy as np
import tensorrt as trt
from PIL import Image

try:
    import pycuda.autoinit  # noqa: F401 (initializes CUDA context)
    import pycuda.driver as cuda
except ImportError as exc:  # pragma: no cover
    raise SystemExit(
        "pycuda is required. Install with: pip install pycuda"
    ) from exc

ONNX_PATH = Path(__file__).with_name("embedl_mobilevit_small_int8.onnx")
ENGINE_PATH = Path(__file__).with_name("embedl_mobilevit_small_int8.engine")
INPUT_SIZE = (256, 256)
# Identity values: the upstream HF MobileViT processor applies no mean/std
# normalization (do_normalize=None); these stay as explicit no-op hooks.
MEAN = np.array([0.0, 0.0, 0.0], dtype=np.float32)
STD = np.array([1.0, 1.0, 1.0], dtype=np.float32)

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)


def build_engine() -> bytes:
    """Parse the shipped ONNX file and build a serialized TensorRT engine.

    Returns:
        The serialized engine as ``bytes``.

    Raises:
        RuntimeError: if ONNX parsing fails (parser errors are printed
            first) or the engine build returns nothing.
    """
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, TRT_LOGGER)
    with open(ONNX_PATH, "rb") as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError("ONNX parse failed.")

    config = builder.create_builder_config()
    # FP16 + INT8 together let TensorRT pick the fastest precision per
    # layer while honoring the Q/DQ scales baked into the graph.
    config.set_flag(trt.BuilderFlag.FP16)
    config.set_flag(trt.BuilderFlag.INT8)
    config.builder_optimization_level = 5
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("Engine build failed.")
    return bytes(serialized)


def load_or_build_engine() -> trt.ICudaEngine:
    """Return a deserialized engine, building (or rebuilding) as needed.

    ``deserialize_cuda_engine`` returns ``None`` when a cached engine was
    produced by a different TensorRT version or for a different GPU; in
    that case the engine is rebuilt from the ONNX file and the cache file
    is refreshed instead of crashing later with an opaque error.

    Raises:
        RuntimeError: if even a freshly built engine fails to deserialize.
    """
    runtime = trt.Runtime(TRT_LOGGER)
    if ENGINE_PATH.exists():
        engine = runtime.deserialize_cuda_engine(ENGINE_PATH.read_bytes())
        if engine is not None:
            return engine
        # Stale cache (TensorRT/GPU mismatch) — fall through to rebuild.
        print(f"Cached {ENGINE_PATH.name} failed to deserialize; rebuilding …")
    else:
        print(f"Building engine (first run) → {ENGINE_PATH.name} …")
    data = build_engine()
    ENGINE_PATH.write_bytes(data)
    engine = runtime.deserialize_cuda_engine(data)
    if engine is None:
        raise RuntimeError("Freshly built engine failed to deserialize.")
    return engine


def preprocess(image_path: Path) -> np.ndarray:
    """Load *image_path* and return a contiguous (1, 3, H, W) float32 batch.

    MobileViT-Small uses BGR channel order, [0, 1] range, NO mean/std
    normalization (matches the upstream HF processor: do_normalize=None).
    """
    image = Image.open(image_path).convert("RGB").resize(INPUT_SIZE)
    arr = np.asarray(image, dtype=np.float32) / 255.0
    arr = (arr - MEAN) / STD  # identity with the defaults above
    arr = arr[..., ::-1]  # RGB -> BGR
    # HWC -> CHW, prepend batch dim; contiguous so the H2D copy is valid.
    return np.ascontiguousarray(arr.transpose(2, 0, 1)[None])


def main() -> None:
    """CLI entry point: preprocess one image, run the engine, print top-k."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--image", required=True, type=Path)
    parser.add_argument("--topk", type=int, default=5)
    args = parser.parse_args()

    if not ONNX_PATH.exists():
        raise SystemExit(
            f"Expected {ONNX_PATH.name} next to this script. "
            "Did you download the HF repo?"
        )
    if not args.image.exists():
        # Fail early with a clear message instead of deep inside PIL.
        raise SystemExit(f"Image not found: {args.image}")

    engine = load_or_build_engine()
    context = engine.create_execution_context()

    # Identify I/O tensors by mode — TensorRT does not guarantee that
    # index 0 is the input and index 1 the output.
    input_name = output_name = None
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            input_name = name
        else:
            output_name = name
    if input_name is None or output_name is None:
        raise RuntimeError("Engine must expose one input and one output.")

    out_shape = tuple(engine.get_tensor_shape(output_name))
    # Query the real output dtype: with the FP16 flag set it need not be
    # float32, and a wrong host dtype corrupts the D2H copy.
    out_dtype = trt.nptype(engine.get_tensor_dtype(output_name))

    x = preprocess(args.image)
    h_out = np.empty(out_shape, dtype=out_dtype)
    d_in = cuda.mem_alloc(x.nbytes)
    d_out = cuda.mem_alloc(h_out.nbytes)
    stream = cuda.Stream()

    cuda.memcpy_htod_async(d_in, x, stream)
    context.set_tensor_address(input_name, int(d_in))
    context.set_tensor_address(output_name, int(d_out))

    # Warm-up amortizes lazy CUDA/TensorRT initialization before timing.
    for _ in range(5):
        context.execute_async_v3(stream.handle)
    stream.synchronize()

    t0 = time.perf_counter()
    context.execute_async_v3(stream.handle)
    stream.synchronize()
    latency_ms = (time.perf_counter() - t0) * 1000.0

    cuda.memcpy_dtoh_async(h_out, d_out, stream)
    stream.synchronize()

    # Numerically stable softmax over the flattened logits.
    logits = h_out.astype(np.float32).reshape(-1)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    top = probs.argsort()[::-1][: args.topk]

    print(f"Latency (single-run, GPU compute): {latency_ms:.2f} ms")
    print(f"Top-{args.topk} predictions for {args.image}:")
    for rank, idx in enumerate(top, 1):
        print(f" {rank}. class {idx:>5d} ({probs[idx] * 100:5.2f}%)")


if __name__ == "__main__":
    main()