File size: 4,631 Bytes
d31c7d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Copyright (C) 2026 Embedl AB
"""Run inference on the Embedl Mobilevit Small INT8 model via TensorRT.

This script builds a TensorRT engine from the shipped
``embedl_mobilevit_small_int8.onnx`` artifact (Q/DQ nodes baked in by
embedl-deploy) and runs a single image through it. The first run
caches the engine to ``embedl_mobilevit_small_int8.engine`` so reuse is
fast.

Requires TensorRT >= 10.1 and pycuda (or cuda-python). Tested on
NVIDIA Jetson AGX Orin (JetPack 6) and discrete GPUs with CUDA 12.

Usage::

    python infer_trt.py --image path/to/image.jpg
"""

import argparse
import time
from pathlib import Path

import numpy as np
import tensorrt as trt
from PIL import Image

try:
    import pycuda.autoinit  # noqa: F401  (initializes CUDA context)
    import pycuda.driver as cuda
except ImportError as exc:  # pragma: no cover
    raise SystemExit(
        "pycuda is required. Install with: pip install pycuda"
    ) from exc

# Model artifacts live next to this script; the .engine file is a cache
# written on first run by load_or_build_engine().
ONNX_PATH = Path(__file__).with_name("embedl_mobilevit_small_int8.onnx")
ENGINE_PATH = Path(__file__).with_name("embedl_mobilevit_small_int8.engine")
# MobileViT-Small input resolution (width, height) for PIL.resize.
INPUT_SIZE = (256, 256)
# Identity normalization: per the note in preprocess(), the upstream HF
# processor does no mean/std normalization (only /255 scaling), so these
# are deliberately 0/1 and the normalization line is a no-op kept for
# formula symmetry.
MEAN = np.array([0.0, 0.0, 0.0], dtype=np.float32)
STD = np.array([1.0, 1.0, 1.0], dtype=np.float32)

# Single logger shared by builder, parser, and runtime.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)


def build_engine() -> bytes:
    """Parse the shipped ONNX model and return a serialized TensorRT plan.

    Both FP16 and INT8 builder flags are set so TensorRT honors the Q/DQ
    nodes baked into the graph while using half precision where no
    quantization is specified.

    Raises:
        RuntimeError: if the ONNX graph fails to parse or the engine
            build returns no plan.
    """
    builder = trt.Builder(TRT_LOGGER)
    flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(flags)
    onnx_parser = trt.OnnxParser(network, TRT_LOGGER)

    if not onnx_parser.parse(ONNX_PATH.read_bytes()):
        # Surface every parser diagnostic before bailing out.
        for err_idx in range(onnx_parser.num_errors):
            print(onnx_parser.get_error(err_idx))
        raise RuntimeError("ONNX parse failed.")

    config = builder.create_builder_config()
    config.set_flag(trt.BuilderFlag.FP16)
    config.set_flag(trt.BuilderFlag.INT8)
    # Highest optimization level: slowest build, fastest engine.
    config.builder_optimization_level = 5

    plan = builder.build_serialized_network(network, config)
    if plan is None:
        raise RuntimeError("Engine build failed.")
    return bytes(plan)


def load_or_build_engine() -> trt.ICudaEngine:
    """Return a deserialized engine, building and caching it when needed.

    ``Runtime.deserialize_cuda_engine`` returns ``None`` (rather than
    raising) when a cached engine was built by a different TensorRT
    version or on different hardware. The original code passed that
    ``None`` straight to the caller, which then crashed on
    ``create_execution_context``; here a stale cache is discarded and
    the engine is rebuilt from the ONNX file instead.

    Raises:
        RuntimeError: if even a freshly built plan fails to deserialize.
    """
    runtime = trt.Runtime(TRT_LOGGER)

    if ENGINE_PATH.exists():
        engine = runtime.deserialize_cuda_engine(ENGINE_PATH.read_bytes())
        if engine is not None:
            return engine
        # Stale/incompatible cache (e.g. TensorRT upgrade, new GPU).
        print(f"Cached engine {ENGINE_PATH.name} is stale; rebuilding …")
    else:
        print(f"Building engine (first run) → {ENGINE_PATH.name} …")

    data = build_engine()
    ENGINE_PATH.write_bytes(data)
    engine = runtime.deserialize_cuda_engine(data)
    if engine is None:
        raise RuntimeError("Failed to deserialize freshly built engine.")
    return engine


def preprocess(image_path: Path) -> np.ndarray:
    """Load an image and return a contiguous NCHW float32 batch of one.

    MobileViT-Small uses BGR channel order and a [0, 1] value range with
    NO mean/std normalization (matches the upstream HF processor:
    do_normalize=None); MEAN/STD are identity values, so the
    normalization step below is a deliberate no-op.
    """
    rgb = Image.open(image_path).convert("RGB").resize(INPUT_SIZE)
    pixels = np.asarray(rgb, dtype=np.float32) / 255.0
    pixels = (pixels - MEAN) / STD
    bgr = pixels[..., ::-1]                  # RGB -> BGR
    nchw = bgr.transpose(2, 0, 1)[None]      # HWC -> CHW, add batch dim
    return np.ascontiguousarray(nchw)


def main() -> None:
    """Parse CLI args, run one image through the engine, print top-k.

    Fixes over the original:
    - I/O tensors are identified by their ``TensorIOMode`` instead of by
      hard-coded indices 0/1 — tensor ordering is not a guaranteed part
      of the engine API and may change across TensorRT versions.
    - The host output buffer uses the engine-reported dtype via
      ``trt.nptype`` instead of assuming float32.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--image", required=True, type=Path)
    parser.add_argument("--topk", type=int, default=5)
    args = parser.parse_args()

    if not ONNX_PATH.exists():
        raise SystemExit(
            f"Expected {ONNX_PATH.name} next to this script. "
            "Did you download the HF repo?"
        )

    engine = load_or_build_engine()
    context = engine.create_execution_context()

    # Locate the (single) input and output by role, not by index.
    input_name = output_name = None
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            input_name = name
        else:
            output_name = name
    if input_name is None or output_name is None:
        raise RuntimeError("Engine must expose one input and one output.")

    out_shape = tuple(engine.get_tensor_shape(output_name))
    out_dtype = trt.nptype(engine.get_tensor_dtype(output_name))

    x = preprocess(args.image)
    h_out = np.empty(out_shape, dtype=out_dtype)

    d_in = cuda.mem_alloc(x.nbytes)
    d_out = cuda.mem_alloc(h_out.nbytes)
    stream = cuda.Stream()

    cuda.memcpy_htod_async(d_in, x, stream)
    context.set_tensor_address(input_name, int(d_in))
    context.set_tensor_address(output_name, int(d_out))

    # Warm-up amortizes lazy CUDA/TRT initialization before the timed run.
    for _ in range(5):
        context.execute_async_v3(stream.handle)
    stream.synchronize()
    t0 = time.perf_counter()
    context.execute_async_v3(stream.handle)
    stream.synchronize()
    latency_ms = (time.perf_counter() - t0) * 1000.0

    cuda.memcpy_dtoh_async(h_out, d_out, stream)
    stream.synchronize()

    # Numerically stable softmax over the flattened logits (batch of 1);
    # cast up front in case the engine emits FP16 output.
    logits = h_out.reshape(-1).astype(np.float32)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    top = probs.argsort()[::-1][: args.topk]

    print(f"Latency (single-run, GPU compute): {latency_ms:.2f} ms")
    print(f"Top-{args.topk} predictions for {args.image}:")
    for rank, idx in enumerate(top, 1):
        print(f"  {rank}. class {idx:>5d}  ({probs[idx] * 100:5.2f}%)")


if __name__ == "__main__":
    main()