dann-od commited on
Commit
72ba9a3
·
verified ·
1 Parent(s): 07e5183

Reworked the infer_trt script

Browse files
Files changed (2) hide show
  1. README.md +5 -4
  2. infer_trt.py +145 -52
README.md CHANGED
@@ -91,14 +91,15 @@ forecasting and **ctx=2048** for long-history use cases.
91
  ## Quick Start
92
 
93
  ```bash
94
- pip install onnxruntime-gpu numpy
95
  python infer_trt.py --ctx 512 # 1.2× faster than FP16 on Orin
96
  python infer_trt.py --ctx 2048 # 1.3× faster than FP16 on Orin
97
  ```
98
 
99
- The `infer_trt.py` helper script creates a synthetic seasonal context
100
- for demonstration; replace it with your own series of the right
101
- length.
 
102
 
103
  ## Files
104
 
 
91
  ## Quick Start
92
 
93
  ```bash
94
+ pip install tensorrt pycuda numpy
95
  python infer_trt.py --ctx 512 # 1.2× faster than FP16 on Orin
96
  python infer_trt.py --ctx 2048 # 1.3× faster than FP16 on Orin
97
  ```
98
 
99
+ The `infer_trt.py` helper script builds a TensorRT engine from the
100
+ ONNX on first run (cached as `*.engine` next to the artifact) and
101
+ feeds a synthetic seasonal context for demonstration. Replace the
102
+ context generator with your own series of the right length.
103
 
104
  ## Files
105
 
infer_trt.py CHANGED
@@ -1,87 +1,180 @@
1
- #!/usr/bin/env python3
2
- """Run inference with embedl-deploy's INT8 chronos-2 on TensorRT.
3
 
4
- Reads a context series, runs the model, prints the median forecast.
5
- Uses ONNX Runtime's TensorrtExecutionProvider; falls back to CUDA / CPU
6
- if TRT isn't available.
 
 
7
 
8
- Usage::
 
9
 
10
- pip install onnxruntime-gpu numpy
11
- python infer_trt.py --ctx 512 # or --ctx 2048
12
 
13
- The script generates a synthetic seasonal context for demonstration;
14
- swap in your own series of the right length.
15
  """
16
- from __future__ import annotations
17
 
18
  import argparse
19
- import sys
20
  from pathlib import Path
21
 
22
  import numpy as np
23
- import onnxruntime as ort
24
-
25
- # chronos-2 emits 21 evenly spaced quantile levels along axis 1 of
26
- # the output. The median (q=0.5) is element 10.
 
 
 
 
 
 
 
 
27
  MEDIAN_IDX = 10
28
- NUM_OUTPUT_PATCHES = 64 # baked into the ONNX
29
- OUTPUT_PATCH_SIZE = 16 # baked into the ONNX
30
- MODEL_HORIZON = NUM_OUTPUT_PATCHES * OUTPUT_PATCH_SIZE # 1024 steps
31
 
 
32
 
33
- def _make_session(onnx_path: Path) -> ort.InferenceSession:
34
- providers = [
35
- ("TensorrtExecutionProvider", {"trt_int8_enable": True}),
36
- "CUDAExecutionProvider",
37
- "CPUExecutionProvider",
38
- ]
39
- return ort.InferenceSession(str(onnx_path), providers=providers)
40
 
41
-
42
- def main() -> int:
43
- parser = argparse.ArgumentParser()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  parser.add_argument(
45
  "--ctx", type=int, choices=(512, 2048), default=512,
46
  help="Static context length of the artifact to use.",
47
  )
48
  parser.add_argument(
49
  "--horizon", type=int, default=48,
50
- help="How many steps of the median forecast to print "
51
- f"(capped at MODEL_HORIZON={MODEL_HORIZON}).",
52
  )
53
  args = parser.parse_args()
54
  if args.horizon > MODEL_HORIZON:
55
- sys.exit(f"--horizon must be <= {MODEL_HORIZON}")
56
 
57
  onnx_path = Path(__file__).with_name(
58
  f"embedl_chronos_2_ctx{args.ctx}_int8.onnx"
59
  )
 
 
60
  if not onnx_path.exists():
61
- sys.exit(f"Missing {onnx_path}; run `huggingface-cli download` first.")
62
-
63
- # Synthetic seasonal context for demonstration.
64
- t = np.arange(args.ctx, dtype=np.float32)
65
- context = (
66
- 10.0 + 5.0 * np.sin(2 * np.pi * t / 24)
67
- + 2.0 * np.sin(2 * np.pi * t / 168)
68
- + 0.3 * np.random.RandomState(0).standard_normal(args.ctx).astype(np.float32)
69
- ).reshape(1, args.ctx).astype(np.float32)
70
  group_ids = np.zeros((1,), dtype=np.int64)
71
 
72
- session = _make_session(onnx_path)
73
- print(f"Providers in use: {session.get_providers()}")
74
- preds = session.run(
75
- None,
76
- {"context": context, "group_ids": group_ids},
77
- )[0]
78
- # preds shape: (1, 21, 1024)
79
- median = preds[0, MEDIAN_IDX, : args.horizon]
80
- print(f"Median forecast (first {args.horizon} steps):")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  np.set_printoptions(precision=3, suppress=True, linewidth=120)
 
 
 
 
82
  print(median)
83
- return 0
84
 
85
 
86
  if __name__ == "__main__":
87
- sys.exit(main())
 
1
+ # Copyright (C) 2026 Embedl AB
2
+ """Run inference on the Embedl Chronos-2 INT8 forecaster via TensorRT.
3
 
4
+ Builds a TensorRT engine from the shipped
5
+ ``embedl_chronos_2_ctx{512,2048}_int8.onnx`` artifact (Q/DQ nodes baked
6
+ in by embedl-deploy) and produces a 21-quantile forecast for a context
7
+ time series. The first run caches the engine to
8
+ ``embedl_chronos_2_ctx{ctx}_int8.engine`` so reuse is fast.
9
 
10
+ Requires TensorRT >= 10.1, pycuda (or cuda-python), and numpy. Tested
11
+ on NVIDIA Jetson AGX Orin (JetPack 6) and discrete GPUs with CUDA 12.
12
 
13
+ Usage::
 
14
 
15
+ python infer_trt.py --ctx 512 # synthetic input
16
+ python infer_trt.py --ctx 2048 --horizon 96 # longer history, custom horizon
17
  """
 
18
 
19
  import argparse
20
+ import time
21
  from pathlib import Path
22
 
23
  import numpy as np
24
+ import tensorrt as trt
25
+
26
+ try:
27
+ import pycuda.autoinit # noqa: F401 (initializes CUDA context)
28
+ import pycuda.driver as cuda
29
+ except ImportError as exc: # pragma: no cover
30
+ raise SystemExit(
31
+ "pycuda is required. Install with: pip install pycuda"
32
+ ) from exc
33
+
34
+ # chronos-2 emits 21 evenly spaced quantile levels along axis 1 of the
35
+ # output tensor. The median (q=0.5) is element 10.
36
  MEDIAN_IDX = 10
37
+ NUM_OUTPUT_PATCHES = 64
38
+ OUTPUT_PATCH_SIZE = 16
39
+ MODEL_HORIZON = NUM_OUTPUT_PATCHES * OUTPUT_PATCH_SIZE # 1024
40
 
41
+ TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
42
 
 
 
 
 
 
 
 
43
 
44
+ def build_engine(onnx_path: Path) -> bytes:
45
+ builder = trt.Builder(TRT_LOGGER)
46
+ network = builder.create_network(
47
+ 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
48
+ )
49
+ parser = trt.OnnxParser(network, TRT_LOGGER)
50
+ with open(onnx_path, "rb") as f:
51
+ if not parser.parse(f.read()):
52
+ for i in range(parser.num_errors):
53
+ print(parser.get_error(i))
54
+ raise RuntimeError("ONNX parse failed.")
55
+ config = builder.create_builder_config()
56
+ config.set_flag(trt.BuilderFlag.FP16)
57
+ config.set_flag(trt.BuilderFlag.INT8)
58
+ config.builder_optimization_level = 5
59
+ serialized = builder.build_serialized_network(network, config)
60
+ if serialized is None:
61
+ raise RuntimeError("Engine build failed.")
62
+ return bytes(serialized)
63
+
64
+
65
+ def load_or_build_engine(
66
+ onnx_path: Path, engine_path: Path,
67
+ ) -> trt.ICudaEngine:
68
+ if engine_path.exists():
69
+ data = engine_path.read_bytes()
70
+ else:
71
+ print(f"Building engine (first run) → {engine_path.name} …")
72
+ data = build_engine(onnx_path)
73
+ engine_path.write_bytes(data)
74
+ runtime = trt.Runtime(TRT_LOGGER)
75
+ return runtime.deserialize_cuda_engine(data)
76
+
77
+
78
+ def make_synthetic_context(ctx_len: int) -> np.ndarray:
79
+ """24h + 168h seasonal sine wave plus mild noise. Replace with
80
+ your own series of length ``ctx_len``."""
81
+ t = np.arange(ctx_len, dtype=np.float32)
82
+ rng = np.random.RandomState(0)
83
+ return (
84
+ 10.0 + 5.0 * np.sin(2 * np.pi * t / 24.0)
85
+ + 2.0 * np.sin(2 * np.pi * t / 168.0)
86
+ + 0.3 * rng.standard_normal(ctx_len).astype(np.float32)
87
+ ).reshape(1, ctx_len).astype(np.float32)
88
+
89
+
90
+ def main() -> None:
91
+ parser = argparse.ArgumentParser(description=__doc__)
92
  parser.add_argument(
93
  "--ctx", type=int, choices=(512, 2048), default=512,
94
  help="Static context length of the artifact to use.",
95
  )
96
  parser.add_argument(
97
  "--horizon", type=int, default=48,
98
+ help=f"How many steps of the median forecast to print "
99
+ f"(model emits {MODEL_HORIZON}; capped here).",
100
  )
101
  args = parser.parse_args()
102
  if args.horizon > MODEL_HORIZON:
103
+ raise SystemExit(f"--horizon must be <= {MODEL_HORIZON}")
104
 
105
  onnx_path = Path(__file__).with_name(
106
  f"embedl_chronos_2_ctx{args.ctx}_int8.onnx"
107
  )
108
+ engine_path = onnx_path.with_suffix(".engine")
109
+
110
  if not onnx_path.exists():
111
+ raise SystemExit(
112
+ f"Expected {onnx_path.name} next to this script. "
113
+ "Did you download the HF repo?"
114
+ )
115
+
116
+ context = make_synthetic_context(args.ctx)
 
 
 
117
  group_ids = np.zeros((1,), dtype=np.int64)
118
 
119
+ engine = load_or_build_engine(onnx_path, engine_path)
120
+ exec_context = engine.create_execution_context()
121
+
122
+ # Resolve I/O tensor names by mode (input vs output) — order in the
123
+ # engine isn't guaranteed to match get_tensor_name(0..N).
124
+ input_names = []
125
+ output_names = []
126
+ for i in range(engine.num_io_tensors):
127
+ name = engine.get_tensor_name(i)
128
+ if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
129
+ input_names.append(name)
130
+ else:
131
+ output_names.append(name)
132
+ if len(input_names) != 2 or len(output_names) != 1:
133
+ raise RuntimeError(
134
+ f"Expected 2 inputs / 1 output, got "
135
+ f"{len(input_names)} / {len(output_names)}."
136
+ )
137
+
138
+ # Bind by canonical name so context / group_ids land on the right
139
+ # input tensor regardless of engine ordering.
140
+ inputs = {"context": context, "group_ids": group_ids}
141
+
142
+ out_shape = tuple(engine.get_tensor_shape(output_names[0]))
143
+ h_out = np.empty(out_shape, dtype=np.float32)
144
+
145
+ d_inputs = {
146
+ name: cuda.mem_alloc(inputs[name].nbytes) for name in input_names
147
+ }
148
+ d_out = cuda.mem_alloc(h_out.nbytes)
149
+ stream = cuda.Stream()
150
+
151
+ for name in input_names:
152
+ cuda.memcpy_htod_async(d_inputs[name], inputs[name], stream)
153
+ exec_context.set_tensor_address(name, int(d_inputs[name]))
154
+ exec_context.set_tensor_address(output_names[0], int(d_out))
155
+
156
+ # Warm-up + timed run.
157
+ for _ in range(5):
158
+ exec_context.execute_async_v3(stream.handle)
159
+ stream.synchronize()
160
+ t0 = time.perf_counter()
161
+ exec_context.execute_async_v3(stream.handle)
162
+ stream.synchronize()
163
+ latency_ms = (time.perf_counter() - t0) * 1000.0
164
+
165
+ cuda.memcpy_dtoh_async(h_out, d_out, stream)
166
+ stream.synchronize()
167
+
168
+ # h_out shape: (1, 21, MODEL_HORIZON). Take the median quantile
169
+ # (index MEDIAN_IDX) and clip to the requested horizon.
170
+ median = h_out[0, MEDIAN_IDX, : args.horizon]
171
  np.set_printoptions(precision=3, suppress=True, linewidth=120)
172
+ print(f"Latency (single-run, GPU compute): {latency_ms:.2f} ms")
173
+ print(f"Context length: {args.ctx}")
174
+ print(f"Output shape: {tuple(h_out.shape)}")
175
+ print(f"Median forecast (first {args.horizon} steps):")
176
  print(median)
 
177
 
178
 
179
  if __name__ == "__main__":
180
+ main()