# Hugging Face page header: dann-od - "MobileVit-Small First commit" - d31c7d4 (verified)
# Copyright (C) 2026 Embedl AB
"""Run inference on the Embedl Mobilevit Small INT8 model via TensorRT.
This script builds a TensorRT engine from the shipped
``embedl_mobilevit_small_int8.onnx`` artifact (Q/DQ nodes baked in by
embedl-deploy) and runs a single image through it. The first run
caches the engine to ``embedl_mobilevit_small_int8.engine`` so reuse is
fast.
Requires TensorRT >= 10.1 and pycuda. Tested on
NVIDIA Jetson AGX Orin (JetPack 6) and discrete GPUs with CUDA 12.
Usage::
python infer_trt.py --image path/to/image.jpg
"""
import argparse
import time
from pathlib import Path
import numpy as np
import tensorrt as trt
from PIL import Image
try:
import pycuda.autoinit # noqa: F401 (initializes CUDA context)
import pycuda.driver as cuda
except ImportError as exc: # pragma: no cover
raise SystemExit(
"pycuda is required. Install with: pip install pycuda"
) from exc
# Model artifacts live next to this script, so the script can be launched
# from any working directory.
ONNX_PATH = Path(__file__).parent / "embedl_mobilevit_small_int8.onnx"
# The cached engine shares the ONNX file's stem and only differs in suffix.
ENGINE_PATH = ONNX_PATH.with_suffix(".engine")

# Target size handed to ``PIL.Image.resize`` in ``preprocess``.
INPUT_SIZE = (256, 256)

# Per-channel normalization constants. Zero mean and unit std make the
# ``(arr - MEAN) / STD`` step in ``preprocess`` an identity transform.
MEAN = np.zeros(3, dtype=np.float32)
STD = np.ones(3, dtype=np.float32)

# Shared TensorRT logger; only warnings and more severe messages are shown.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
def build_engine() -> bytes:
    """Build a serialized TensorRT engine from the ONNX file at ``ONNX_PATH``.

    The ONNX model is parsed into a TensorRT network and compiled with both
    the FP16 and INT8 builder flags enabled.

    Returns:
        The serialized engine as ``bytes``, suitable for writing to disk or
        passing to ``trt.Runtime.deserialize_cuda_engine``.

    Raises:
        RuntimeError: If the ONNX parser rejects the model (parser errors are
            printed first) or the builder fails to produce an engine.
    """
    builder = trt.Builder(TRT_LOGGER)

    # EXPLICIT_BATCH is deprecated and ignored as of TensorRT 10 (networks are
    # always explicit-batch). Look it up defensively so this keeps working if
    # the enum member is removed in a later release, while still setting the
    # flag on versions that define it.
    explicit_batch = getattr(
        trt.NetworkDefinitionCreationFlag, "EXPLICIT_BATCH", None
    )
    creation_flags = 0 if explicit_batch is None else 1 << int(explicit_batch)
    network = builder.create_network(creation_flags)

    parser = trt.OnnxParser(network, TRT_LOGGER)
    with open(ONNX_PATH, "rb") as f:
        parsed = parser.parse(f.read())
    if not parsed:
        # Surface every parser diagnostic before failing.
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError("ONNX parse failed.")

    config = builder.create_builder_config()
    # Allow both reduced-precision modes; the builder chooses per layer.
    config.set_flag(trt.BuilderFlag.FP16)
    config.set_flag(trt.BuilderFlag.INT8)
    # Level 5 is above TensorRT's default optimization level; presumably this
    # trades a longer build for a faster engine.
    config.builder_optimization_level = 5

    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("Engine build failed.")
    return bytes(serialized)
def load_or_build_engine() -> trt.ICudaEngine:
    """Return a deserialized engine, using the on-disk cache when possible.

    If ``ENGINE_PATH`` exists it is deserialized directly. When that fails
    (``deserialize_cuda_engine`` returns ``None``, e.g. for an engine file
    that this TensorRT installation cannot load), the engine is rebuilt from
    the ONNX model and the cache file is overwritten. When no cache exists
    the engine is built and cached.

    Returns:
        The deserialized ``trt.ICudaEngine``.

    Raises:
        RuntimeError: If building fails (propagated from ``build_engine``) or
            a freshly built engine cannot be deserialized.
    """
    runtime = trt.Runtime(TRT_LOGGER)

    if ENGINE_PATH.exists():
        engine = runtime.deserialize_cuda_engine(ENGINE_PATH.read_bytes())
        if engine is not None:
            return engine
        # An unusable cache must not be fatal: fall through and rebuild.
        print(f"Cached engine {ENGINE_PATH.name} could not be loaded; rebuilding ...")
    else:
        print(f"Building engine (first run) → {ENGINE_PATH.name} …")

    data = build_engine()
    ENGINE_PATH.write_bytes(data)

    engine = runtime.deserialize_cuda_engine(data)
    if engine is None:
        raise RuntimeError("Failed to deserialize the TensorRT engine.")
    return engine
def preprocess(image_path: Path) -> np.ndarray:
    """Load an image file and convert it into the network's input tensor.

    The image is converted to RGB, resized to ``INPUT_SIZE``, scaled to the
    [0, 1] range, passed through ``(x - MEAN) / STD`` (an identity with the
    module's current constants), flipped to BGR channel order and laid out
    channels-first with a leading batch axis.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A C-contiguous ``float32`` array with a leading batch dimension of 1
        followed by the channel axis and the two spatial axes.
    """
    # MobileViT-Small uses BGR channel order, [0, 1] range, NO mean/std
    # normalization (per the original author's note this matches the upstream
    # HF processor: do_normalize=None).
    # The context manager closes the underlying file handle promptly;
    # convert() and resize() return new in-memory images, so the result
    # remains valid after the file is closed.
    with Image.open(image_path) as img:
        resized = img.convert("RGB").resize(INPUT_SIZE)

    arr = np.asarray(resized, dtype=np.float32) / 255.0
    arr = (arr - MEAN) / STD
    arr = arr[..., ::-1]  # RGB -> BGR
    # Channels-last -> channels-first, then add the batch axis.
    return np.ascontiguousarray(arr.transpose(2, 0, 1)[None])
def main() -> None:
    """Run one image through the engine and print latency and top-k classes.

    Command-line arguments:
        --image: Path of the image to classify (required).
        --topk:  Number of highest-probability classes to print (default 5).

    Raises:
        SystemExit: If the ONNX file is missing, the engine or execution
            context cannot be created, or the engine does not expose both an
            input and an output tensor.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--image", required=True, type=Path)
    parser.add_argument("--topk", type=int, default=5)
    args = parser.parse_args()

    if not ONNX_PATH.exists():
        raise SystemExit(
            f"Expected {ONNX_PATH.name} next to this script. "
            "Did you download the HF repo?"
        )

    engine = load_or_build_engine()
    if engine is None:
        raise SystemExit("Failed to load the TensorRT engine.")
    context = engine.create_execution_context()
    if context is None:
        raise SystemExit("Failed to create a TensorRT execution context.")

    # Resolve the I/O tensors by their declared mode rather than assuming
    # that index 0 is the input and index 1 is the output.
    input_name = None
    output_name = None
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            if input_name is None:
                input_name = name
        elif output_name is None:
            output_name = name
    if input_name is None or output_name is None:
        raise SystemExit(
            "Engine must expose at least one input and one output tensor."
        )

    out_shape = tuple(engine.get_tensor_shape(output_name))
    # Size the host buffer with the engine's real output dtype instead of
    # assuming float32, so the device-to-host copy is interpreted correctly.
    out_dtype = trt.nptype(engine.get_tensor_dtype(output_name))

    x = preprocess(args.image)
    h_out = np.empty(out_shape, dtype=out_dtype)
    d_in = cuda.mem_alloc(x.nbytes)
    d_out = cuda.mem_alloc(h_out.nbytes)
    stream = cuda.Stream()

    # The host arrays are ordinary pageable numpy buffers. PyCUDA documents
    # that the *_async copy variants need page-locked memory, so use the
    # blocking copies here; only the inference itself runs on the stream.
    cuda.memcpy_htod(d_in, x)
    context.set_tensor_address(input_name, int(d_in))
    context.set_tensor_address(output_name, int(d_out))

    # Warm-up + timed run.
    for _ in range(5):
        context.execute_async_v3(stream.handle)
    stream.synchronize()

    t0 = time.perf_counter()
    context.execute_async_v3(stream.handle)
    stream.synchronize()
    latency_ms = (time.perf_counter() - t0) * 1000.0

    # The stream was synchronized above, so the output is complete.
    cuda.memcpy_dtoh(h_out, d_out)

    # Softmax in float32; subtracting the max keeps exp() from overflowing.
    logits = h_out.reshape(-1).astype(np.float32)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    top = probs.argsort()[::-1][: args.topk]

    print(f"Latency (single-run, GPU compute): {latency_ms:.2f} ms")
    print(f"Top-{args.topk} predictions for {args.image}:")
    for rank, idx in enumerate(top, 1):
        print(f" {rank}. class {idx:>5d} ({probs[idx] * 100:5.2f}%)")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()