| |
"""Run inference on the Embedl Mobilevit Small INT8 model via TensorRT.

This script builds a TensorRT engine from the shipped
``embedl_mobilevit_small_int8.onnx`` artifact (Q/DQ nodes baked in by
embedl-deploy) and runs a single image through it. The first run
caches the engine to ``embedl_mobilevit_small_int8.engine`` so that later
runs start quickly.

Requires TensorRT >= 10.1 and pycuda (cuda-python alone is not enough:
the script imports pycuda and exits if it is missing). Tested on
NVIDIA Jetson AGX Orin (JetPack 6) and discrete GPUs with CUDA 12.

Usage::

    python infer_trt.py --image path/to/image.jpg
"""
|
|
| import argparse |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| import tensorrt as trt |
| from PIL import Image |
|
|
| try: |
| import pycuda.autoinit |
| import pycuda.driver as cuda |
| except ImportError as exc: |
| raise SystemExit( |
| "pycuda is required. Install with: pip install pycuda" |
| ) from exc |
|
|
# Model artifacts live in the same directory as this script; the cached
# engine shares the ONNX file's stem and differs only in its suffix.
ONNX_PATH = Path(__file__).parent / "embedl_mobilevit_small_int8.onnx"
ENGINE_PATH = ONNX_PATH.with_suffix(".engine")

# Size handed to ``PIL.Image.resize`` during preprocessing.
INPUT_SIZE = (256, 256)

# Per-channel normalisation constants. With zeros and ones the
# ``(x - MEAN) / STD`` step in preprocessing leaves pixel values unchanged.
MEAN = np.zeros(3, dtype=np.float32)
STD = np.ones(3, dtype=np.float32)
|
|
# Single TensorRT logger shared by the builder, the ONNX parser and the
# runtime in this script; constructed with the WARNING severity level.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
|
|
|
def build_engine() -> bytes:
    """Build a serialized TensorRT engine from the shipped ONNX model.

    The ONNX file at ``ONNX_PATH`` is parsed into a TensorRT network and
    compiled with the FP16 and INT8 builder flags enabled at builder
    optimization level 5.

    Returns:
        The serialized engine as ``bytes``.

    Raises:
        RuntimeError: If the ONNX model cannot be parsed (the parser's
            error list is included in the message) or if the build
            returns no engine.
    """
    builder = trt.Builder(TRT_LOGGER)
    # NOTE(review): EXPLICIT_BATCH is reported as deprecated/ignored in
    # TensorRT 10 — kept as-is for compatibility; confirm before removing.
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, TRT_LOGGER)

    # Passing the model path alongside the buffer lets the parser locate
    # externally stored weights next to the ONNX file, if there are any.
    if not parser.parse(ONNX_PATH.read_bytes(), str(ONNX_PATH)):
        errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
        for message in errors:
            print(message)
        # Carry the diagnostics in the exception too, so they are not lost
        # when stdout is not captured.
        details = "\n".join(errors)
        raise RuntimeError(
            "ONNX parse failed." + (f"\n{details}" if details else "")
        )

    config = builder.create_builder_config()
    config.set_flag(trt.BuilderFlag.FP16)
    config.set_flag(trt.BuilderFlag.INT8)
    config.builder_optimization_level = 5

    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("Engine build failed.")
    return bytes(serialized)
|
|
|
|
def load_or_build_engine() -> trt.ICudaEngine:
    """Return a deserialized engine, building and caching it when needed.

    If ``ENGINE_PATH`` exists its contents are deserialized. When that
    fails (the runtime returns ``None``, e.g. for a cache file the current
    runtime cannot load) the engine is rebuilt from the ONNX model and the
    cache file is overwritten instead of handing ``None`` to the caller.

    Returns:
        The deserialized TensorRT engine.

    Raises:
        RuntimeError: If the ONNX parse or engine build fails (propagated
            from ``build_engine``), or if the freshly built engine cannot
            be deserialized.
    """
    runtime = trt.Runtime(TRT_LOGGER)

    if ENGINE_PATH.exists():
        engine = runtime.deserialize_cuda_engine(ENGINE_PATH.read_bytes())
        if engine is not None:
            return engine
        print(
            f"Cached engine {ENGINE_PATH.name} could not be loaded; "
            "rebuilding …"
        )
    else:
        print(f"Building engine (first run) → {ENGINE_PATH.name} …")

    data = build_engine()
    try:
        ENGINE_PATH.write_bytes(data)
    except OSError as exc:
        # Caching is only an optimisation: keep going with the in-memory
        # engine rather than discarding a finished build.
        print(f"Warning: could not cache engine to {ENGINE_PATH}: {exc}")

    engine = runtime.deserialize_cuda_engine(data)
    if engine is None:
        raise RuntimeError("Failed to deserialize the TensorRT engine.")
    return engine
|
|
|
|
def preprocess(image_path: Path) -> np.ndarray:
    """Load an image and convert it into the network's input tensor.

    The image is converted to RGB, resized to ``INPUT_SIZE``, scaled to
    ``[0, 1]``, normalised with ``MEAN``/``STD`` and then has its channel
    axis reversed (RGB -> BGR) before being laid out channels-first.

    Args:
        image_path: Path of the image file to load.

    Returns:
        A C-contiguous ``float32`` array of shape ``(1, 3, H, W)``.
    """
    # The context manager closes the underlying file handle as soon as the
    # pixel data has been converted, instead of leaving that to the GC.
    with Image.open(image_path) as source:
        resized = source.convert("RGB").resize(INPUT_SIZE)

    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    # MEAN/STD are applied while the channels are still in RGB order.
    pixels = (pixels - MEAN) / STD
    # Reverse the last (channel) axis: RGB -> BGR.
    pixels = pixels[..., ::-1]
    # HWC -> CHW, then add a leading batch axis; the reversed view is not
    # contiguous, so copy it into a contiguous buffer for the device upload.
    return np.ascontiguousarray(pixels.transpose(2, 0, 1)[None])
|
|
|
|
def _io_tensor_names(engine):
    """Return ``(input_names, output_names)`` using each tensor's I/O mode."""
    inputs, outputs = [], []
    for index in range(engine.num_io_tensors):
        name = engine.get_tensor_name(index)
        mode = engine.get_tensor_mode(name)
        if mode == trt.TensorIOMode.INPUT:
            inputs.append(name)
        elif mode == trt.TensorIOMode.OUTPUT:
            outputs.append(name)
    return inputs, outputs


def main() -> None:
    """Command-line entry point: classify one image and print the top-k.

    Parses ``--image`` and ``--topk``, preprocesses the image, loads (or
    builds) the engine, runs five warm-up inferences followed by one timed
    inference, and prints the latency plus the top-k softmax probabilities.

    Raises:
        SystemExit: If an argument is invalid, the image or ONNX file is
            missing, the engine/context cannot be created, or the engine's
            I/O layout does not match what this script feeds it.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--image", required=True, type=Path)
    parser.add_argument("--topk", type=int, default=5)
    args = parser.parse_args()

    if args.topk < 1:
        parser.error("--topk must be a positive integer")

    # Fail fast on bad inputs before the (potentially slow) engine build.
    if not args.image.is_file():
        raise SystemExit(f"Image not found: {args.image}")
    if not ONNX_PATH.exists():
        raise SystemExit(
            f"Expected {ONNX_PATH.name} next to this script. "
            "Did you download the HF repo?"
        )

    x = preprocess(args.image)

    engine = load_or_build_engine()
    if engine is None:
        raise SystemExit(
            f"Failed to load the TensorRT engine. Delete {ENGINE_PATH.name} "
            "and run again to rebuild it."
        )
    context = engine.create_execution_context()
    if context is None:
        raise SystemExit("Failed to create a TensorRT execution context.")

    # Identify input/output by their declared I/O mode rather than assuming
    # that tensor 0 is the input and tensor 1 is the output.
    input_names, output_names = _io_tensor_names(engine)
    if len(input_names) != 1 or len(output_names) != 1:
        raise SystemExit(
            "Expected exactly one input and one output tensor, found "
            f"inputs={input_names} outputs={output_names}."
        )
    input_name, output_name = input_names[0], output_names[0]

    # The device input buffer is sized from the host array, so a shape
    # mismatch would make the engine read past the end of the allocation.
    in_shape = tuple(engine.get_tensor_shape(input_name))
    if in_shape != x.shape:
        raise SystemExit(
            f"Engine input {input_name!r} has shape {in_shape}, but the "
            f"preprocessed image has shape {x.shape}."
        )

    out_shape = tuple(engine.get_tensor_shape(output_name))
    # Size the host buffer from the engine's declared output dtype instead
    # of assuming float32.
    out_dtype = trt.nptype(engine.get_tensor_dtype(output_name))
    h_out = np.empty(out_shape, dtype=out_dtype)

    d_in = cuda.mem_alloc(x.nbytes)
    d_out = cuda.mem_alloc(h_out.nbytes)
    stream = cuda.Stream()

    cuda.memcpy_htod_async(d_in, x, stream)
    context.set_tensor_address(input_name, int(d_in))
    context.set_tensor_address(output_name, int(d_out))

    # Five warm-up runs, each synchronised, before the timed run.
    for _ in range(5):
        context.execute_async_v3(stream.handle)
        stream.synchronize()

    # Time a single enqueue + synchronise; host/device copies are excluded.
    start = time.perf_counter()
    context.execute_async_v3(stream.handle)
    stream.synchronize()
    latency_ms = (time.perf_counter() - start) * 1000.0

    cuda.memcpy_dtoh_async(h_out, d_out, stream)
    stream.synchronize()

    # Numerically stable softmax, computed in float32 whatever the output
    # dtype of the engine is.
    logits = h_out.reshape(-1).astype(np.float32)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()

    # Never report more entries than the model has classes.
    k = min(args.topk, probs.size)
    top = probs.argsort()[::-1][:k]

    print(f"Latency (single-run, GPU compute): {latency_ms:.2f} ms")
    print(f"Top-{k} predictions for {args.image}:")
    for rank, idx in enumerate(top, 1):
        print(f" {rank}. class {idx:>5d} ({probs[idx] * 100:5.2f}%)")
|
|
|
|
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|