File size: 6,364 Bytes
72ba9a3
 
17b6c94
72ba9a3
 
 
 
 
17b6c94
72ba9a3
 
17b6c94
72ba9a3
17b6c94
72ba9a3
 
17b6c94
 
 
72ba9a3
17b6c94
 
 
72ba9a3
 
 
 
 
 
 
 
 
 
 
 
17b6c94
72ba9a3
 
 
17b6c94
72ba9a3
17b6c94
 
72ba9a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17b6c94
 
 
 
 
 
72ba9a3
 
17b6c94
 
 
72ba9a3
17b6c94
 
 
 
72ba9a3
 
17b6c94
72ba9a3
 
 
 
 
 
17b6c94
 
72ba9a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17b6c94
72ba9a3
 
 
 
17b6c94
 
 
 
72ba9a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# Copyright (C) 2026 Embedl AB
"""Run inference on the Embedl Chronos-2 INT8 forecaster via TensorRT.

Builds a TensorRT engine from the shipped
``embedl_chronos_2_ctx{512,2048}_int8.onnx`` artifact (Q/DQ nodes baked
in by embedl-deploy) and produces a 21-quantile forecast for a context
time series. The first run caches the engine to
``embedl_chronos_2_ctx{ctx}_int8.engine`` so reuse is fast.

Requires TensorRT >= 10.1, pycuda, and numpy. Tested
on NVIDIA Jetson AGX Orin (JetPack 6) and discrete GPUs with CUDA 12.

Usage::

    python infer_trt.py --ctx 512                  # synthetic input
    python infer_trt.py --ctx 2048 --horizon 96    # longer history, custom horizon
"""

import argparse
import time
from pathlib import Path

import numpy as np
import tensorrt as trt

try:
    import pycuda.autoinit  # noqa: F401  (initializes CUDA context)
    import pycuda.driver as cuda
except ImportError as exc:  # pragma: no cover
    raise SystemExit(
        "pycuda is required. Install with: pip install pycuda"
    ) from exc

# Layout of the Chronos-2 forecast tensor.
# Axis 1 holds 21 quantile levels; index 10 is the median (q=0.5).
# NOTE(review): confirm the exact quantile level list against the model
# config before relying on any index other than the median.
OUTPUT_PATCH_SIZE = 16
NUM_OUTPUT_PATCHES = 64
MODEL_HORIZON = OUTPUT_PATCH_SIZE * NUM_OUTPUT_PATCHES  # 16 * 64 = 1024 steps
MEDIAN_IDX = 10

# Shared TensorRT logger passed to the builder, ONNX parser and runtime;
# only messages at WARNING severity or above are emitted.
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)


def build_engine(onnx_path: Path) -> bytes:
    """Parse *onnx_path* and build a serialized TensorRT engine.

    Both the FP16 and INT8 builder flags are enabled so TensorRT may run
    the Q/DQ-quantized parts of the graph in INT8 and the rest in FP16.

    Args:
        onnx_path: Path to the ONNX model to compile.

    Returns:
        The serialized engine (plan) as ``bytes``.

    Raises:
        RuntimeError: If ONNX parsing fails (the parser's error messages
            are included in the exception text) or the build fails.
    """
    builder = trt.Builder(TRT_LOGGER)
    # TensorRT 10 networks are always explicit-batch and the EXPLICIT_BATCH
    # flag is deprecated there; only request it on older releases.
    if int(trt.__version__.split(".")[0]) >= 10:
        network_flags = 0
    else:
        network_flags = 1 << int(
            trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH
        )
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, TRT_LOGGER)
    # Parsing from the file path (rather than from raw bytes) lets the
    # parser locate ONNX external-data weight files next to the model.
    if not parser.parse_from_file(str(onnx_path)):
        errors = "\n".join(
            str(parser.get_error(i)) for i in range(parser.num_errors)
        )
        raise RuntimeError(f"ONNX parse failed for {onnx_path}:\n{errors}")
    config = builder.create_builder_config()
    config.set_flag(trt.BuilderFlag.FP16)
    config.set_flag(trt.BuilderFlag.INT8)
    # Highest optimization level: slower build, potentially faster engine.
    config.builder_optimization_level = 5
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("Engine build failed.")
    return bytes(serialized)


def load_or_build_engine(
    onnx_path: Path, engine_path: Path,
) -> trt.ICudaEngine:
    """Return a deserialized engine, reusing *engine_path* when possible.

    The cached engine is rebuilt from *onnx_path* when:

    * it does not exist yet (first run);
    * *onnx_path* is newer than the cache (model was re-downloaded);
    * TensorRT cannot deserialize it (e.g. built by a different TensorRT
      version or for a different GPU).

    A freshly built engine is written to a temporary file and then renamed
    over *engine_path*, so an interrupted run never leaves a truncated
    engine behind.

    Raises:
        RuntimeError: If a freshly built engine cannot be deserialized, or
            propagated from :func:`build_engine`.
    """
    runtime = trt.Runtime(TRT_LOGGER)
    if engine_path.exists():
        stale = (
            onnx_path.exists()
            and onnx_path.stat().st_mtime > engine_path.stat().st_mtime
        )
        if stale:
            print(
                f"{onnx_path.name} is newer than {engine_path.name}; "
                "rebuilding …"
            )
        else:
            engine = runtime.deserialize_cuda_engine(engine_path.read_bytes())
            if engine is not None:
                return engine
            print(
                f"Cached {engine_path.name} could not be deserialized "
                "(TensorRT/GPU mismatch?); rebuilding …"
            )
    else:
        print(f"Building engine (first run) → {engine_path.name} …")

    data = build_engine(onnx_path)
    tmp_path = engine_path.with_name(engine_path.name + ".tmp")
    tmp_path.write_bytes(data)
    tmp_path.replace(engine_path)  # atomic rename on the same filesystem

    engine = runtime.deserialize_cuda_engine(data)
    if engine is None:
        raise RuntimeError(
            f"Failed to deserialize freshly built engine {engine_path.name}."
        )
    return engine


def make_synthetic_context(ctx_len: int) -> np.ndarray:
    """Return a deterministic demo series shaped ``(1, ctx_len)``, float32.

    The series is a level of 10 plus a period-24 sine (amplitude 5), a
    period-168 sine (amplitude 2) and seeded Gaussian noise (std 0.3).
    Swap this out for your own series of length ``ctx_len``.
    """
    steps = np.arange(ctx_len, dtype=np.float32)
    jitter = (
        np.random.RandomState(0)
        .standard_normal(ctx_len)
        .astype(np.float32)
    )
    daily = 5.0 * np.sin(2 * np.pi * steps / 24.0)
    weekly = 2.0 * np.sin(2 * np.pi * steps / 168.0)
    series = 10.0 + daily + weekly + 0.3 * jitter
    return series.astype(np.float32).reshape(1, ctx_len)


def _parse_args() -> argparse.Namespace:
    """Parse and validate the command line.

    ``--horizon`` must lie in ``1..MODEL_HORIZON``: a zero or negative value
    would otherwise silently select the wrong slice of the forecast
    (``arr[:-5]`` drops the *last* five steps instead of keeping five).
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's "Usage::" layout intact in --help.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--ctx", type=int, choices=(512, 2048), default=512,
        help="Static context length of the artifact to use.",
    )
    parser.add_argument(
        "--horizon", type=int, default=48,
        help=f"How many steps of the median forecast to print "
        f"(model emits {MODEL_HORIZON}; capped here).",
    )
    args = parser.parse_args()
    if not 1 <= args.horizon <= MODEL_HORIZON:
        parser.error(f"--horizon must be between 1 and {MODEL_HORIZON}")
    return args


def _split_io_names(
    engine: trt.ICudaEngine,
) -> "tuple[list[str], list[str]]":
    """Partition the engine's I/O tensor names into ``(inputs, outputs)``.

    Tensor indices in the engine are not guaranteed to follow ONNX input
    order, so names are classified by their I/O mode instead.
    """
    input_names = []
    output_names = []
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            input_names.append(name)
        else:
            output_names.append(name)
    return input_names, output_names


def main() -> None:
    """Build/load the engine, run one forecast, print latency and median.

    Raises:
        SystemExit: On invalid arguments or a missing ONNX artifact.
        RuntimeError: If the engine cannot be loaded, its I/O signature is
            not the expected two inputs (``context``, ``group_ids``) and
            one output, or TensorRT execution fails.
    """
    args = _parse_args()

    onnx_path = Path(__file__).with_name(
        f"embedl_chronos_2_ctx{args.ctx}_int8.onnx"
    )
    engine_path = onnx_path.with_suffix(".engine")

    if not onnx_path.exists():
        raise SystemExit(
            f"Expected {onnx_path.name} next to this script. "
            "Did you download the HF repo?"
        )

    context = make_synthetic_context(args.ctx)
    group_ids = np.zeros((1,), dtype=np.int64)

    engine = load_or_build_engine(onnx_path, engine_path)
    if engine is None:
        raise RuntimeError(f"Failed to load TensorRT engine {engine_path.name}.")
    exec_context = engine.create_execution_context()

    input_names, output_names = _split_io_names(engine)
    if len(input_names) != 2 or len(output_names) != 1:
        raise RuntimeError(
            f"Expected 2 inputs / 1 output, got "
            f"{len(input_names)} / {len(output_names)}."
        )

    # Bind by canonical name so context / group_ids land on the right
    # input tensor regardless of engine ordering.
    inputs = {"context": context, "group_ids": group_ids}
    unknown = [name for name in input_names if name not in inputs]
    if unknown:
        raise RuntimeError(
            f"Unexpected engine input name(s) {unknown}; "
            f"expected {sorted(inputs)}."
        )

    # Stage inputs in page-locked host memory (pycuda's async copies
    # require it) using the dtype the engine declares, which may differ
    # from the numpy default (e.g. an int32 group_ids input).
    h_inputs = {}
    for name in input_names:
        h_buf = cuda.pagelocked_empty(
            inputs[name].shape, trt.nptype(engine.get_tensor_dtype(name))
        )
        h_buf[...] = inputs[name]
        h_inputs[name] = h_buf
        # Dynamic dimensions must be pinned before output shapes resolve.
        if -1 in tuple(engine.get_tensor_shape(name)):
            exec_context.set_input_shape(name, h_buf.shape)

    out_name = output_names[0]
    out_shape = tuple(exec_context.get_tensor_shape(out_name))
    if any(dim < 0 for dim in out_shape):
        raise RuntimeError(
            f"Output {out_name!r} has unresolved shape {out_shape}."
        )
    h_out = cuda.pagelocked_empty(
        out_shape, trt.nptype(engine.get_tensor_dtype(out_name))
    )

    d_inputs = {
        name: cuda.mem_alloc(h_buf.nbytes) for name, h_buf in h_inputs.items()
    }
    d_out = cuda.mem_alloc(h_out.nbytes)
    stream = cuda.Stream()

    for name, h_buf in h_inputs.items():
        cuda.memcpy_htod_async(d_inputs[name], h_buf, stream)
        exec_context.set_tensor_address(name, int(d_inputs[name]))
    exec_context.set_tensor_address(out_name, int(d_out))

    def run_once() -> None:
        # execute_async_v3 reports enqueue failure via its return value.
        if not exec_context.execute_async_v3(stream.handle):
            raise RuntimeError("TensorRT execute_async_v3 failed.")

    # Warm-up + timed run.
    for _ in range(5):
        run_once()
    stream.synchronize()
    t0 = time.perf_counter()
    run_once()
    stream.synchronize()
    latency_ms = (time.perf_counter() - t0) * 1000.0

    cuda.memcpy_dtoh_async(h_out, d_out, stream)
    stream.synchronize()

    # h_out is expected to be (1, 21, MODEL_HORIZON). Take the median
    # quantile (index MEDIAN_IDX) and clip to the requested horizon.
    median = h_out[0, MEDIAN_IDX, : args.horizon]
    np.set_printoptions(precision=3, suppress=True, linewidth=120)
    print(f"Latency (single-run, GPU compute): {latency_ms:.2f} ms")
    print(f"Context length:                    {args.ctx}")
    print(f"Output shape:                      {tuple(h_out.shape)}")
    print(f"Median forecast (first {args.horizon} steps):")
    print(median)


# Run only when executed as a script; importing this module runs nothing.
if __name__ == "__main__":
    main()