dann-od commited on
Commit
90c4404
·
verified ·
1 Parent(s): 4759a9c

First version of model card

Browse files
.gitattributes CHANGED
@@ -20,6 +20,7 @@
20
  *.pickle filter=lfs diff=lfs merge=lfs -text
21
  *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
 
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
 
20
  *.pickle filter=lfs diff=lfs merge=lfs -text
21
  *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pt2 filter=lfs diff=lfs merge=lfs -text
24
  *.pth filter=lfs diff=lfs merge=lfs -text
25
  *.rar filter=lfs diff=lfs merge=lfs -text
26
  *.safetensors filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,5 +1,174 @@
1
- ---
2
- license: other
3
- license_name: embedl-models-community-licence-1.0
4
- license_link: https://github.com/embedl/embedl-models/blob/main/LICENSE
5
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: other
3
+ license_name: embedl-models-community-licence-1.0
4
+ license_link: https://github.com/embedl/embedl-models/blob/main/LICENSE
5
+ base_model:
6
+ - sentence-transformers/all-MiniLM-L6-v2
7
+ quantized_from:
8
+ - sentence-transformers/all-MiniLM-L6-v2
9
+ tags:
10
+ - sentence-similarity
11
+ - quantization
12
+ - onnx
13
+ - tensorrt
14
+ - edge
15
+ - embedl
16
+ gated: true
17
+ extra_gated_heading: "Access Embedl All Minilm L6 V2"
18
+ extra_gated_description: "To access this model, please review and accept the terms below. Your contact information is collected solely to manage access and, with your explicit consent, to notify you about updated or new optimized models from Embedl."
19
+ extra_gated_button_content: "Agree and request access"
20
+ extra_gated_prompt: "By requesting access you agree to the Embedl Models Community Licence and the upstream All Minilm L6 V2 License"
21
+ extra_gated_fields:
22
+ Company: text
23
+ I agree to the Embedl Models Community Licence and upstream All Minilm L6 V2 License: checkbox
24
+ I consent to being contacted by Embedl about products and services (optional): checkbox
25
+ ---
26
+ <!-- embedl-banner:start -->
27
+ <style>
28
+ .embedl-btn-primary { transition: background 160ms ease, box-shadow 160ms ease; }
29
+ .embedl-btn-primary:hover { background: #4FDCE4 !important; box-shadow: 0 8px 22px rgba(45,212,221,0.45) !important; }
30
+ .embedl-btn-secondary { transition: background 160ms ease; }
31
+ .embedl-btn-secondary:hover { background: rgba(45,212,221,0.15) !important; }
32
+ .embedl-headline { font-size: clamp(11px, 2.15vw, 15px) !important; }
33
+ .embedl-btn-primary, .embedl-btn-secondary {
34
+ font-size: clamp(11px, 1.65vw, 13px) !important;
35
+ padding: clamp(6px, 1.1vw, 9px) clamp(10px, 1.6vw, 14px) !important;
36
+ }
37
+ </style>
38
+ <div style="background:radial-gradient(600px 220px at 0% 50%,rgba(45,212,221,0.22) 0%,rgba(45,212,221,0) 60%),radial-gradient(400px 180px at 100% 100%,rgba(45,212,221,0.10) 0%,rgba(45,212,221,0) 55%),linear-gradient(135deg,#0B1626 0%,#142338 100%);border:1px solid rgba(45,212,221,0.28);border-radius:12px;padding:22px 24px;margin:0 0 24px 0;color:#F2F6FA;box-shadow:0 4px 16px rgba(11,22,38,0.18);overflow:hidden;box-sizing:border-box;max-width:100%;">
39
+ <table style="width:100%;border-collapse:collapse;border:0;background:transparent;">
40
+ <tr style="background:transparent;">
41
+ <td style="vertical-align:middle;border:0;padding:0;background:transparent;">
42
+ <div style="display:inline-block;font-size:10px;letter-spacing:0.08em;text-transform:uppercase;font-weight:700;color:#2DD4DD;background:rgba(45,212,221,0.15);border:1px solid rgba(45,212,221,0.35);padding:4px 10px;border-radius:999px;margin-bottom:10px;white-space:nowrap;">Optimized by Embedl</div>
43
+ <div class="embedl-headline" style="font-size:15px;font-weight:700;line-height:1.35;color:#F2F6FA;margin-bottom:4px;">Need to <span style="color:#2DD4DD;white-space:nowrap;">fine-tune</span>, hit <span style="color:#2DD4DD;white-space:nowrap;">performance targets</span>, or deploy on <span style="color:#2DD4DD;white-space:nowrap;">specific hardware</span>?</div>
44
+ <div style="font-size:13px;color:#9BA7B5;">We've got you covered.</div>
45
+ </td>
46
+ <td width="1%" style="vertical-align:middle;border:0;padding:0 0 0 18px;white-space:nowrap;text-align:right;background:transparent;">
47
+ <a href="https://www.embedl.com/models" class="embedl-btn-secondary" style="display:inline-block;font-size:13px;font-weight:600;padding:9px 14px;border-radius:6px;border:1px solid #2DD4DD;color:#2DD4DD;text-decoration:none;margin-right:8px;">Learn more</a>
48
+ <a href="https://www.embedl.com/contact" class="embedl-btn-primary" style="display:inline-block;font-size:13px;font-weight:600;padding:9px 14px;border-radius:6px;border:1px solid #2DD4DD;background:#2DD4DD;color:#0B1626;text-decoration:none;box-shadow:0 6px 18px rgba(45,212,221,0.28);">Get in touch →</a>
49
+ </td>
50
+ </tr>
51
+ </table>
52
+ </div>
53
+ <!-- embedl-banner:end -->
54
+
55
+ # Embedl All Minilm L6 V2 (Quantized for TensorRT)
56
+
57
+ Deployable INT8-quantized version of [`sentence-transformers/all-MiniLM-L6-v2`](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2),
58
+ optimized with [embedl-deploy](https://github.com/embedl/embedl-deploy)
59
+ for low-latency NVIDIA TensorRT inference on edge GPUs. Produces
60
+ the same L2-normalised sentence embedding as the upstream encoder,
61
+ at reduced latency (see the Performance table below).
62
+
63
+ ## Upstream Model
64
+
65
+ <a href="https://hfviewer.com/sentence-transformers/all-MiniLM-L6-v2?utm_source=huggingface&amp;utm_medium=embedded_model_card&amp;utm_campaign=sentence-transformers__all-MiniLM-L6-v2_card" target="_blank" rel="noopener">
66
+ <img
67
+ src="https://hfviewer.com/api/card.svg?source=sentence-transformers%2Fall-MiniLM-L6-v2&amp;v=20260501clipcard"
68
+ alt="Open sentence-transformers/all-MiniLM-L6-v2 in hfviewer"
69
+ width="100%"
70
+ />
71
+ </a>
72
+
73
+ ## Highlights
74
+
75
+ - **Mixed-precision INT8/FP16 quantization** with hardware-aware
76
+ optimizations from [embedl-deploy](https://github.com/embedl/embedl-deploy).
77
+ - **Drop-in replacement** for `sentence-transformers/all-MiniLM-L6-v2` in TensorRT pipelines —
78
+ same input pair (input_ids, attention_mask) at seq_len=128, same output embedding semantics
79
+ (mean-pooled, L2-normalised).
80
+ - **Validated accuracy** within 0.0026 of the FP32 Spearman ρ on stsb
81
+ (see Accuracy table below).
82
+ - **Faster than `trtexec --best`** on supported NVIDIA hardware (see Performance table below).
83
+ - Includes both **ONNX** (for TensorRT) and **PT2**
84
+ (`torch.export`-loadable) artifacts plus runnable inference scripts.
85
+
86
+ ## Quick Start
87
+
88
+ ```bash
89
+ pip install huggingface_hub transformers numpy
90
+ python -c "from huggingface_hub import snapshot_download; snapshot_download('embedl/all-MiniLM-L6-v2-quantized-trt', local_dir='.')"
91
+ python infer_pt2.py --sentence "A man is eating food." # pure PyTorch via torch.export
92
+ # or
93
+ python infer_trt.py --sentence "A man is eating food." # TensorRT (requires pycuda + tensorrt)
94
+ ```
95
+
96
+ ## Files
97
+
98
+ | File | Purpose |
99
+ |---|---|
100
+ | `embedl_all-MiniLM-L6-v2_int8.onnx` | INT8-quantized ONNX with Q/DQ nodes — feed to TensorRT. |
101
+ | `embedl_all-MiniLM-L6-v2_int8.pt2` | INT8-quantized `torch.export` ExportedProgram. |
102
+ | `infer_trt.py` | Build a TRT engine from the ONNX and run sample inference. |
103
+ | `infer_pt2.py` | Load the `.pt2` with `torch.export.load` and run sample inference. |
104
+
105
+ ## Performance
106
+
107
+ Latency measured with TensorRT + `trtexec`, GPU compute time only
108
+ (`--noDataTransfers`), CUDA Graph + Spin Wait enabled, clocks locked
109
+ (`nvpmodel -m 0 && jetson_clocks` on Jetson).
110
+
111
+ <img src="https://huggingface.co/datasets/embedl/documentation-images/resolve/main/all-MiniLM-L6-v2-quantized-trt/all-MiniLM-L6-v2-quantized-trt__orin-mountain-view.svg" alt="All Minilm L6 V2 benchmark on NVIDIA Jetson AGX Orin">
112
+
113
+ ### NVIDIA Jetson AGX Orin
114
+
115
+ | Configuration | Mean Latency | Speedup vs FP16 |
116
+ |---|---|---|
117
+ | TensorRT FP16 | 0.41 ms | 1.00x |
118
+ | TensorRT --best (unconstrained) | 0.41 ms | 1.01x |
119
+ | **Embedl Deploy INT8** | **0.38 ms** | **1.07x** |
120
+
121
+
122
+ ## Accuracy
123
+
124
+ Evaluated on the stsb validation split. The quantized model
125
+ retains nearly all of the FP32 accuracy within a small tolerance.
126
+
127
+ | Model | Spearman ρ |
128
+ |---|---|
129
+ | `sentence-transformers/all-MiniLM-L6-v2` FP32 (ours) | 0.8672 |
130
+ | **Embedl All Minilm L6 V2 INT8** | **0.8646** |
131
+
132
+ ## Creating Your Own Optimized Models
133
+
134
+ This artifact was produced with
135
+ [embedl-deploy](https://github.com/embedl/embedl-deploy),
136
+ Embedl's open-source PyTorch → TensorRT deployment library. You can
137
+ apply the same workflow to your own models — see
138
+ [the documentation](https://github.com/embedl/embedl-deploy#readme)
139
+ for installation and usage.
140
+
141
+ ## License
142
+
143
+ | Component | License |
144
+ |---|---|
145
+ | Optimized model artifacts (this repo) | [Embedl Models Community Licence v1.0](https://github.com/embedl/embedl-models/blob/main/LICENSE) — no redistribution as a hosted service |
146
+ | Upstream architecture and weights | [All Minilm L6 V2 License](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) |
147
+
148
+ ## Contact
149
+
150
+ We offer engineering support for on-prem/edge deployments and partner
151
+ co-marketing opportunities. Reach out at
152
+ [contact@embedl.com](mailto:contact@embedl.com), or open an issue on
153
+ [GitHub](https://github.com/embedl/embedl-deploy).
154
+
155
+ <!-- embedl-discord-banner:start -->
156
+ <style>
157
+ .embedl-discord-btn { transition: background 160ms ease, box-shadow 160ms ease; }
158
+ .embedl-discord-btn:hover { background: #6C77F5 !important; box-shadow: 0 8px 22px rgba(88,101,242,0.55) !important; }
159
+ </style>
160
+ <div style="background:radial-gradient(600px 220px at 0% 50%,rgba(88,101,242,0.22) 0%,rgba(88,101,242,0) 60%),radial-gradient(400px 180px at 100% 100%,rgba(88,101,242,0.10) 0%,rgba(88,101,242,0) 55%),linear-gradient(135deg,#0B1626 0%,#142338 100%);border:1px solid rgba(88,101,242,0.35);border-radius:12px;padding:22px 24px;margin:24px 0 0 0;color:#F2F6FA;box-shadow:0 4px 16px rgba(11,22,38,0.18);overflow:hidden;box-sizing:border-box;max-width:100%;">
161
+ <table style="width:100%;border-collapse:collapse;border:0;background:transparent;">
162
+ <tr style="background:transparent;">
163
+ <td style="vertical-align:middle;border:0;padding:0;background:transparent;">
164
+ <div style="display:inline-block;font-size:10px;letter-spacing:0.08em;text-transform:uppercase;font-weight:700;color:#A5B4FC;background:rgba(88,101,242,0.18);border:1px solid rgba(88,101,242,0.45);padding:4px 10px;border-radius:999px;margin-bottom:10px;white-space:nowrap;">Community &amp; support</div>
165
+ <div style="font-size:15px;font-weight:700;line-height:1.35;color:#F2F6FA;margin-bottom:4px;">Need help with this model? Chat with the Embedl team and other engineers on <span style="color:#A5B4FC;white-space:nowrap;">Discord</span>.</div>
166
+ <div style="font-size:13px;color:#9BA7B5;">Quantization gotchas, hardware questions, fine-tuning tips — bring them all.</div>
167
+ </td>
168
+ <td width="1%" style="vertical-align:middle;border:0;padding:0 0 0 18px;white-space:nowrap;text-align:right;background:transparent;">
169
+ <a href="https://discord.gg/MTbMWdKqE" class="embedl-discord-btn" style="display:inline-block;font-size:13px;font-weight:600;padding:9px 14px;border-radius:6px;border:1px solid #5865F2;background:#5865F2;color:#FFFFFF;text-decoration:none;box-shadow:0 6px 18px rgba(88,101,242,0.35);">Join our Discord →</a>
170
+ </td>
171
+ </tr>
172
+ </table>
173
+ </div>
174
+ <!-- embedl-discord-banner:end -->
embedl_all-MiniLM-L6-v2_int8.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:693d54efc1f18fa72f6bb6b9245cea05efb5d0c2e3b37c41d5c0e438d7edc5bb
3
+ size 89988793
embedl_all-MiniLM-L6-v2_int8.pt2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f24d335c16b7a94740668484f2f36021b1fd047df3bb50b41da18cf5f122251
3
+ size 134547563
infer_pt2.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2026 Embedl AB
2
+ """Run inference on the Embedl All Minilm L6 V2 INT8 sentence encoder via torch.export.
3
+
4
+ Loads the shipped ``embedl_all-MiniLM-L6-v2_int8.pt2`` artifact with
5
+ ``torch.export.load`` and encodes a sentence (or pair of sentences)
6
+ into an L2-normalised embedding. No TensorRT or ONNX runtime is
7
+ required — just PyTorch + transformers (for the tokenizer).
8
+
9
+ Usage::
10
+
11
+ python infer_pt2.py --sentence "A man is eating food."
12
+ python infer_pt2.py --sentence "A man is eating." \\
13
+ --sentence "A man is having a meal."
14
+ """
15
+
16
+ import argparse
17
+ from pathlib import Path
18
+
19
+ import torch
20
+ from transformers import AutoTokenizer
21
+
22
+ PT2_PATH = Path(__file__).with_name("embedl_all-MiniLM-L6-v2_int8.pt2")
23
+ TOKENIZER_ID = "sentence-transformers/all-MiniLM-L6-v2"
24
+ MAX_LENGTH = 128
25
+
26
+
27
def encode(model: torch.nn.Module, tokenizer, sentence: str) -> torch.Tensor:
    """Tokenize ``sentence`` and run it through ``model``, returning a 1-D embedding.

    The tokenizer pads/truncates to ``MAX_LENGTH`` so the exported model
    always sees a fixed-shape ``(1, MAX_LENGTH)`` input pair.
    """
    tokens = tokenizer(
        sentence,
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )
    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        batched = model(tokens["input_ids"], tokens["attention_mask"])
    # Drop the batch dimension: (1, dim) -> (dim,).
    return batched.squeeze(0)
38
+
39
+
40
def main() -> None:
    """CLI entry point: encode one or more sentences and print the embeddings.

    Accepts ``--sentence`` one or more times; when at least two sentences
    are given, also prints the cosine similarity of the first two.

    Raises:
        SystemExit: if the ``.pt2`` artifact is not next to this script.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--sentence",
        required=True,
        action="append",
        help="Sentence to encode. Pass twice to also print cosine similarity.",
    )
    args = parser.parse_args()

    if not PT2_PATH.exists():
        raise SystemExit(
            f"Expected {PT2_PATH.name} next to this script. "
            "Did you `huggingface-cli download` the repo?"
        )

    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
    # The ExportedProgram captured the model in eval mode at export
    # time, so no further .eval() / no_grad toggling is needed (and
    # neither is supported on the .module() wrapper).
    model = torch.export.load(str(PT2_PATH)).module()

    embeddings = [encode(model, tokenizer, s) for s in args.sentence]

    for i, (sentence, emb) in enumerate(zip(args.sentence, embeddings), 1):
        first8 = ", ".join(f"{v:+.4f}" for v in emb[:8].tolist())
        print(f"[{i}] {sentence!r}")
        print(f"    embedding shape: {tuple(emb.shape)}")
        print(f"    first 8 dims: [{first8}]")

    if len(embeddings) >= 2:
        # Embeddings are L2-normalised, so a plain dot product is the
        # cosine similarity.
        cos = torch.dot(embeddings[0], embeddings[1]).item()
        # Bug fix: the original wrote "\\n", which printed a literal
        # backslash-n instead of starting a new line.
        print(f"\ncosine similarity (sentences 1 & 2): {cos:+.4f}")


if __name__ == "__main__":
    main()
infer_trt.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2026 Embedl AB
2
+ """Run inference on the Embedl All Minilm L6 V2 INT8 sentence encoder via TensorRT.
3
+
4
+ Builds a TensorRT engine from the shipped
5
+ ``embedl_all-MiniLM-L6-v2_int8.onnx`` artifact (Q/DQ nodes baked in by
6
+ embedl-deploy) and encodes a sentence into an L2-normalised
7
+ embedding. The first run caches the engine to
8
+ ``embedl_all-MiniLM-L6-v2_int8.engine`` so reuse is fast.
9
+
10
+ Requires TensorRT >= 10.1, pycuda (or cuda-python), and transformers
11
+ (for the tokenizer). Tested on NVIDIA Jetson AGX Orin (JetPack 6)
12
+ and discrete GPUs with CUDA 12.
13
+
14
+ Usage::
15
+
16
+ python infer_trt.py --sentence "A man is eating food."
17
+ """
18
+
19
+ import argparse
20
+ import time
21
+ from pathlib import Path
22
+
23
+ import numpy as np
24
+ import tensorrt as trt
25
+ from transformers import AutoTokenizer
26
+
27
+ try:
28
+ import pycuda.autoinit # noqa: F401 (initializes CUDA context)
29
+ import pycuda.driver as cuda
30
+ except ImportError as exc: # pragma: no cover
31
+ raise SystemExit(
32
+ "pycuda is required. Install with: pip install pycuda"
33
+ ) from exc
34
+
35
+ ONNX_PATH = Path(__file__).with_name("embedl_all-MiniLM-L6-v2_int8.onnx")
36
+ ENGINE_PATH = Path(__file__).with_name("embedl_all-MiniLM-L6-v2_int8.engine")
37
+ TOKENIZER_ID = "sentence-transformers/all-MiniLM-L6-v2"
38
+ MAX_LENGTH = 128
39
+
40
+ TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
41
+
42
+
43
def build_engine() -> bytes:
    """Parse the shipped INT8 Q/DQ ONNX and build a serialized TensorRT engine.

    Returns:
        The serialized engine plan as bytes.

    Raises:
        RuntimeError: if ONNX parsing or the engine build fails.
    """
    builder = trt.Builder(TRT_LOGGER)
    # Bug fix: TensorRT 10 removed the long-deprecated EXPLICIT_BATCH
    # creation flag (all networks are explicit-batch now), so referencing
    # it unconditionally raises AttributeError on the TRT >= 10.1 this
    # script documents as its minimum. Keep the flag only where it exists
    # so the script also runs on older TRT 8/9 installs.
    flags = 0
    if hasattr(trt.NetworkDefinitionCreationFlag, "EXPLICIT_BATCH"):
        flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(flags)
    parser = trt.OnnxParser(network, TRT_LOGGER)
    with open(ONNX_PATH, "rb") as f:
        if not parser.parse(f.read()):
            # Surface every parser diagnostic before bailing out.
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError("ONNX parse failed.")
    config = builder.create_builder_config()
    # FP16 + INT8 together let TensorRT run the Q/DQ-annotated regions
    # in INT8 and the remainder in FP16 (mixed precision).
    config.set_flag(trt.BuilderFlag.FP16)
    config.set_flag(trt.BuilderFlag.INT8)
    config.builder_optimization_level = 5
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("Engine build failed.")
    return bytes(serialized)
62
+
63
+
64
def load_or_build_engine() -> trt.ICudaEngine:
    """Return a deserialized TensorRT engine, building and caching it on first use."""
    if not ENGINE_PATH.exists():
        print(f"Building engine (first run) → {ENGINE_PATH.name} …")
        ENGINE_PATH.write_bytes(build_engine())
    plan = ENGINE_PATH.read_bytes()
    runtime = trt.Runtime(TRT_LOGGER)
    return runtime.deserialize_cuda_engine(plan)
73
+
74
+
75
def tokenize(tokenizer, sentence: str):
    """Encode ``sentence`` into ``(input_ids, attention_mask)`` int64 numpy arrays.

    Pads/truncates to ``MAX_LENGTH`` for a fixed input shape, and makes
    both arrays contiguous so they can be uploaded to device memory
    directly.
    """
    batch = tokenizer(
        sentence,
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="np",
    )
    ids = np.ascontiguousarray(batch["input_ids"].astype(np.int64))
    mask = np.ascontiguousarray(batch["attention_mask"].astype(np.int64))
    return ids, mask
87
+
88
+
89
def main() -> None:
    """CLI entry point: tokenize one sentence, run it through the TensorRT
    engine, and print the embedding plus a single-run latency figure.

    Raises SystemExit if the ONNX artifact is not next to this script.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--sentence", required=True, type=str)
    args = parser.parse_args()

    if not ONNX_PATH.exists():
        raise SystemExit(
            f"Expected {ONNX_PATH.name} next to this script. "
            "Did you download the HF repo?"
        )

    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
    input_ids, attention_mask = tokenize(tokenizer, args.sentence)

    engine = load_or_build_engine()
    context = engine.create_execution_context()

    # Resolve I/O tensor names by mode (input vs output) — order in
    # the engine isn't guaranteed to match get_tensor_name(0..N).
    input_names = []
    output_names = []
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        mode = engine.get_tensor_mode(name)
        if mode == trt.TensorIOMode.INPUT:
            input_names.append(name)
        else:
            output_names.append(name)
    if len(input_names) != 2 or len(output_names) != 1:
        raise RuntimeError(
            f"Expected 2 inputs / 1 output, got "
            f"{len(input_names)}/{len(output_names)}."
        )

    # Feed the inputs by canonical name so input_ids / attention_mask
    # bind to the right tensor regardless of engine ordering.
    inputs = {"input_ids": input_ids, "attention_mask": attention_mask}

    # NOTE(review): assumes the engine reports a fully static output
    # shape; a dynamic dimension would surface as -1 here and break
    # np.empty — confirm the exported ONNX uses fixed shapes.
    out_shape = tuple(engine.get_tensor_shape(output_names[0]))
    h_out = np.empty(out_shape, dtype=np.float32)

    # Allocate device buffers sized from the (contiguous) host arrays.
    d_inputs = {}
    for name in input_names:
        arr = inputs[name]
        d_inputs[name] = cuda.mem_alloc(arr.nbytes)
    d_out = cuda.mem_alloc(h_out.nbytes)
    stream = cuda.Stream()

    # Async H2D copies plus binding each device pointer to its tensor;
    # the single output tensor is bound after the loop.
    for name in input_names:
        cuda.memcpy_htod_async(d_inputs[name], inputs[name], stream)
        context.set_tensor_address(name, int(d_inputs[name]))
    context.set_tensor_address(output_names[0], int(d_out))

    # Warm-up + timed run.
    for _ in range(5):
        context.execute_async_v3(stream.handle)
    stream.synchronize()
    t0 = time.perf_counter()
    context.execute_async_v3(stream.handle)
    stream.synchronize()
    # Host-side wall clock around one enqueue+sync; excludes the D2H
    # copy below but includes enqueue overhead (rougher than trtexec).
    latency_ms = (time.perf_counter() - t0) * 1000.0

    cuda.memcpy_dtoh_async(h_out, d_out, stream)
    stream.synchronize()

    # Flatten to 1-D for printing; first 8 dims as a quick sanity check.
    embedding = h_out.reshape(-1)
    first8 = ", ".join(f"{v:+.4f}" for v in embedding[:8])
    print(f"Latency (single-run, GPU compute): {latency_ms:.2f} ms")
    print(f"Sentence: {args.sentence!r}")
    print(f"Embedding shape: {embedding.shape}")
    print(f"First 8 dims: [{first8}]")


if __name__ == "__main__":
    main()