# Copyright (C) 2026 Embedl AB
"""Run inference on the Embedl Mobilevit Small INT8 model via TensorRT.
This script builds a TensorRT engine from the shipped
``embedl_mobilevit_small_int8.onnx`` artifact (Q/DQ nodes baked in by
embedl-deploy) and runs a single image through it. The first run
caches the engine to ``embedl_mobilevit_small_int8.engine`` so reuse is
fast.
Requires TensorRT >= 10.1 and pycuda (or cuda-python). Tested on
NVIDIA Jetson AGX Orin (JetPack 6) and discrete GPUs with CUDA 12.
Usage::
python infer_trt.py --image path/to/image.jpg
"""
import argparse
import time
from pathlib import Path

import numpy as np
import tensorrt as trt
from PIL import Image

try:
    import pycuda.autoinit  # noqa: F401 (initializes the CUDA context)
    import pycuda.driver as cuda
except ImportError as exc:  # pragma: no cover
    raise SystemExit(
        "pycuda is required. Install with: pip install pycuda"
    ) from exc

ONNX_PATH = Path(__file__).with_name("embedl_mobilevit_small_int8.onnx")
ENGINE_PATH = Path(__file__).with_name("embedl_mobilevit_small_int8.engine")
INPUT_SIZE = (256, 256)
# Identity normalization: the MobileViT processor only rescales to [0, 1].
MEAN = np.array([0.0, 0.0, 0.0], dtype=np.float32)
STD = np.array([1.0, 1.0, 1.0], dtype=np.float32)

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
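# Bump to trt.Logger.VERBOSE for detailed build diagnostics when debugging.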


def build_engine() -> bytes:
    """Parse the ONNX model and return a serialized TensorRT engine."""
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, TRT_LOGGER)
    with open(ONNX_PATH, "rb") as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError("ONNX parse failed.")
    config = builder.create_builder_config()
    # FP16 + INT8 flags let TensorRT run the Q/DQ-quantized layers in INT8
    # and fall back to FP16 for anything the Q/DQ graph leaves unquantized.
    config.set_flag(trt.BuilderFlag.FP16)
    config.set_flag(trt.BuilderFlag.INT8)
    config.builder_optimization_level = 5
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("Engine build failed.")
    return bytes(serialized)
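
# Assuming the trtexec CLI from the same TensorRT release is available, the
# engine can also be prebuilt offline (handy on Jetson, where builds are slow):
#   trtexec --onnx=embedl_mobilevit_small_int8.onnx --fp16 --int8 \
#           --saveEngine=embedl_mobilevit_small_int8.engine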


def load_or_build_engine() -> trt.ICudaEngine:
    if ENGINE_PATH.exists():
        data = ENGINE_PATH.read_bytes()
    else:
        print(f"Building engine (first run) → {ENGINE_PATH.name} …")
        data = build_engine()
        ENGINE_PATH.write_bytes(data)
    runtime = trt.Runtime(TRT_LOGGER)
    engine = runtime.deserialize_cuda_engine(data)
    if engine is None:
        # A stale cache (different GPU or TensorRT version) deserializes to
        # None; deleting the .engine file forces a rebuild.
        raise RuntimeError(f"Failed to deserialize {ENGINE_PATH.name}.")
    return engine
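
# Note: serialized engines are tied to the exact GPU and TensorRT version that
# built them, so the cached .engine file is not portable across machines.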


def preprocess(image_path: Path) -> np.ndarray:
    # MobileViT-Small expects BGR channel order, [0, 1] range, and NO mean/std
    # normalization (matches the upstream HF processor: do_normalize=None).
    image = Image.open(image_path).convert("RGB").resize(INPUT_SIZE)
    arr = np.asarray(image, dtype=np.float32) / 255.0
    arr = (arr - MEAN) / STD  # identity with the defaults above
    arr = arr[..., ::-1]  # RGB -> BGR
    return np.ascontiguousarray(arr.transpose(2, 0, 1)[None])  # HWC -> NCHW
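
# For reference, preprocess(...) yields a float32 array of shape
# (1, 3, 256, 256): batch 1, CHW layout, matching the engine's input binding.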


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--image", required=True, type=Path)
    parser.add_argument("--topk", type=int, default=5)
    args = parser.parse_args()

    if not ONNX_PATH.exists():
        raise SystemExit(
            f"Expected {ONNX_PATH.name} next to this script. "
            "Did you download the HF repo?"
        )

    engine = load_or_build_engine()
    context = engine.create_execution_context()
    input_name = engine.get_tensor_name(0)
    output_name = engine.get_tensor_name(1)
    out_shape = tuple(engine.get_tensor_shape(output_name))
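    # Assumes one static-shape input and one output (true for this ONNX);
    # a dynamic-shape engine would also need
    # context.set_input_shape(input_name, x.shape) before execution.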

    x = preprocess(args.image)
    h_out = np.empty(out_shape, dtype=np.float32)
    d_in = cuda.mem_alloc(x.nbytes)
    d_out = cuda.mem_alloc(h_out.nbytes)
    stream = cuda.Stream()

    cuda.memcpy_htod_async(d_in, x, stream)
    context.set_tensor_address(input_name, int(d_in))
    context.set_tensor_address(output_name, int(d_out))

    # Warm-up + timed run.
    for _ in range(5):
        context.execute_async_v3(stream.handle)
    stream.synchronize()
    t0 = time.perf_counter()
    context.execute_async_v3(stream.handle)
    stream.synchronize()
    latency_ms = (time.perf_counter() - t0) * 1000.0
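    # Only enqueue + GPU compute fall inside the timed region; host<->device
    # copies are excluded, so this matches the "GPU compute" label below.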

    cuda.memcpy_dtoh_async(h_out, d_out, stream)
    stream.synchronize()

    # Numerically stable softmax over the logits.
    logits = h_out.reshape(-1)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    top = probs.argsort()[::-1][: args.topk]

    print(f"Latency (single-run, GPU compute): {latency_ms:.2f} ms")
    print(f"Top-{args.topk} predictions for {args.image}:")
    for rank, idx in enumerate(top, 1):
        print(f"  {rank}. class {idx:>5d} ({probs[idx] * 100:5.2f}%)")


if __name__ == "__main__":
    main()