# Copyright (C) 2026 Embedl AB

"""Run inference on the Embedl Chronos-2 INT8 forecaster via TensorRT.

Builds a TensorRT engine from the shipped
``embedl_chronos_2_ctx{512,2048}_int8.onnx`` artifact (Q/DQ nodes baked in by
embedl-deploy) and produces a 21-quantile forecast for a context time series.
The first run caches the engine to ``embedl_chronos_2_ctx{ctx}_int8.engine``
so reuse is fast.

Requires TensorRT >= 10.1, pycuda (or cuda-python), and numpy. Tested on
NVIDIA Jetson AGX Orin (JetPack 6) and discrete GPUs with CUDA 12.

Usage::

    python infer_trt.py --ctx 512                 # synthetic input
    python infer_trt.py --ctx 2048 --horizon 96   # longer history, custom horizon
"""

import argparse
import time
from pathlib import Path

import numpy as np
import tensorrt as trt

try:
    import pycuda.autoinit  # noqa: F401  (initializes CUDA context)
    import pycuda.driver as cuda
except ImportError as exc:  # pragma: no cover
    raise SystemExit(
        "pycuda is required. Install with: pip install pycuda"
    ) from exc

# chronos-2 emits 21 evenly spaced quantile levels along axis 1 of the
# output tensor. The median (q=0.5) is element 10.
MEDIAN_IDX = 10
NUM_OUTPUT_PATCHES = 64
OUTPUT_PATCH_SIZE = 16
MODEL_HORIZON = NUM_OUTPUT_PATCHES * OUTPUT_PATCH_SIZE  # 1024

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# Names the script binds its host arrays to; the engine's input tensors
# must carry exactly these names.
_INPUT_NAMES = ("context", "group_ids")


def build_engine(onnx_path: Path) -> bytes:
    """Parse ``onnx_path`` and return a serialized TensorRT engine.

    The builder is configured with the FP16 and INT8 flags and
    optimization level 5.

    Raises:
        RuntimeError: If the ONNX file cannot be parsed (parser errors are
            printed first) or the engine build returns nothing.
    """
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, TRT_LOGGER)
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError("ONNX parse failed.")

    config = builder.create_builder_config()
    config.set_flag(trt.BuilderFlag.FP16)
    config.set_flag(trt.BuilderFlag.INT8)
    config.builder_optimization_level = 5

    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("Engine build failed.")
    return bytes(serialized)


def load_or_build_engine(
    onnx_path: Path,
    engine_path: Path,
) -> trt.ICudaEngine:
    """Return a deserialized engine, using ``engine_path`` as a cache.

    If the cache file exists it is loaded; when that load fails (the
    deserializer returns ``None``, e.g. for a plan the current runtime
    rejects) the engine is rebuilt from ``onnx_path`` and the cache file is
    overwritten. Without a cache file the engine is built and cached.

    Raises:
        RuntimeError: If a freshly built engine cannot be deserialized, or
            if ``build_engine`` fails.
    """
    runtime = trt.Runtime(TRT_LOGGER)

    if engine_path.exists():
        engine = runtime.deserialize_cuda_engine(engine_path.read_bytes())
        if engine is not None:
            return engine
        # A cached plan that no longer loads must not take the script down
        # with an AttributeError later on; fall through and rebuild it.
        print(
            f"Cached engine {engine_path.name} could not be loaded; "
            "rebuilding …"
        )
    else:
        print(f"Building engine (first run) → {engine_path.name} …")

    data = build_engine(onnx_path)
    engine_path.write_bytes(data)
    engine = runtime.deserialize_cuda_engine(data)
    if engine is None:
        raise RuntimeError("Failed to deserialize the freshly built engine.")
    return engine


def make_synthetic_context(ctx_len: int) -> np.ndarray:
    """Return a synthetic float32 series of shape ``(1, ctx_len)``.

    The series is a constant offset plus sine waves of period 24 and 168
    samples and mild Gaussian noise from a fixed seed, so repeated calls
    give identical data. Replace with your own series of length ``ctx_len``.
    """
    t = np.arange(ctx_len, dtype=np.float32)
    rng = np.random.RandomState(0)
    return (
        10.0
        + 5.0 * np.sin(2 * np.pi * t / 24.0)
        + 2.0 * np.sin(2 * np.pi * t / 168.0)
        + 0.3 * rng.standard_normal(ctx_len).astype(np.float32)
    ).reshape(1, ctx_len).astype(np.float32)


def _parse_args() -> argparse.Namespace:
    """Parse the command line and validate ``--horizon``."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--ctx",
        type=int,
        choices=(512, 2048),
        default=512,
        help="Static context length of the artifact to use.",
    )
    parser.add_argument(
        "--horizon",
        type=int,
        default=48,
        help=f"How many steps of the median forecast to print "
        f"(model emits {MODEL_HORIZON}; capped here).",
    )
    args = parser.parse_args()
    # Zero or negative values would slice from the END of the forecast
    # (``[: -5]``) and print something that is not "the first N steps".
    if not 1 <= args.horizon <= MODEL_HORIZON:
        raise SystemExit(
            f"--horizon must be between 1 and {MODEL_HORIZON}"
        )
    return args


def _split_io_names(engine: trt.ICudaEngine) -> tuple:
    """Return ``(input_names, output_names)`` of the engine's I/O tensors."""
    input_names = []
    output_names = []
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            input_names.append(name)
        else:
            output_names.append(name)
    return input_names, output_names


def main() -> None:
    """Run one forecast on synthetic data and print the median quantile.

    Raises:
        SystemExit: On an invalid ``--horizon`` or a missing ONNX artifact.
        RuntimeError: If the engine's I/O layout is not the expected
            ``context`` / ``group_ids`` inputs plus one output, if no
            execution context can be created, or if inference cannot be
            enqueued.
    """
    args = _parse_args()

    onnx_path = Path(__file__).with_name(
        f"embedl_chronos_2_ctx{args.ctx}_int8.onnx"
    )
    engine_path = onnx_path.with_suffix(".engine")
    if not onnx_path.exists():
        raise SystemExit(
            f"Expected {onnx_path.name} next to this script. "
            "Did you download the HF repo?"
        )

    context = make_synthetic_context(args.ctx)
    # One id (0) for the single series in the batch.
    group_ids = np.zeros((1,), dtype=np.int64)

    engine = load_or_build_engine(onnx_path, engine_path)
    exec_context = engine.create_execution_context()
    if exec_context is None:
        raise RuntimeError("Failed to create a TensorRT execution context.")

    # Resolve I/O tensor names by mode (input vs output) — order in the
    # engine isn't guaranteed to match get_tensor_name(0..N).
    input_names, output_names = _split_io_names(engine)
    if len(input_names) != 2 or len(output_names) != 1:
        raise RuntimeError(
            f"Expected 2 inputs / 1 output, got "
            f"{len(input_names)} / {len(output_names)}."
        )
    if set(input_names) != set(_INPUT_NAMES):
        raise RuntimeError(
            f"Expected input tensors {sorted(_INPUT_NAMES)}, got "
            f"{sorted(input_names)}."
        )

    # Bind by canonical name so context / group_ids land on the right
    # input tensor regardless of engine ordering.
    inputs = {"context": context, "group_ids": group_ids}
    output_name = output_names[0]

    out_shape = tuple(engine.get_tensor_shape(output_name))
    h_out = np.empty(out_shape, dtype=np.float32)

    d_inputs = {
        name: cuda.mem_alloc(inputs[name].nbytes) for name in input_names
    }
    d_out = cuda.mem_alloc(h_out.nbytes)

    stream = cuda.Stream()
    for name in input_names:
        cuda.memcpy_htod_async(d_inputs[name], inputs[name], stream)
        exec_context.set_tensor_address(name, int(d_inputs[name]))
    exec_context.set_tensor_address(output_name, int(d_out))

    def enqueue() -> None:
        # execute_async_v3 reports failure through its return value; do not
        # go on to print whatever happens to be in the output buffer.
        if not exec_context.execute_async_v3(stream.handle):
            raise RuntimeError("TensorRT failed to enqueue inference.")

    # Warm-up + timed run.
    for _ in range(5):
        enqueue()
    stream.synchronize()

    t0 = time.perf_counter()
    enqueue()
    stream.synchronize()
    latency_ms = (time.perf_counter() - t0) * 1000.0

    cuda.memcpy_dtoh_async(h_out, d_out, stream)
    stream.synchronize()

    # h_out shape: (1, 21, MODEL_HORIZON). Take the median quantile
    # (index MEDIAN_IDX) and clip to the requested horizon.
    median = h_out[0, MEDIAN_IDX, : args.horizon]

    np.set_printoptions(precision=3, suppress=True, linewidth=120)
    print(f"Latency (single-run, GPU compute): {latency_ms:.2f} ms")
    print(f"Context length: {args.ctx}")
    print(f"Output shape: {tuple(h_out.shape)}")
    print(f"Median forecast (first {args.horizon} steps):")
    print(median)


if __name__ == "__main__":
    main()