# Copyright (C) 2026 Embedl AB
"""Run inference on the Embedl All Minilm L6 V2 INT8 sentence encoder via torch.export.
Loads the shipped ``embedl_all-MiniLM-L6-v2_int8.pt2`` artifact with
``torch.export.load`` and encodes a sentence (or pair of sentences)
into an L2-normalised embedding. No TensorRT or ONNX runtime is
required — just PyTorch + transformers (for the tokenizer).
Usage::
python infer_pt2.py --sentence "A man is eating food."
python infer_pt2.py --sentence "A man is eating." \\
--sentence "A man is having a meal."
"""
import argparse
from pathlib import Path

import torch
from transformers import AutoTokenizer

PT2_PATH = Path(__file__).with_name("embedl_all-MiniLM-L6-v2_int8.pt2")
TOKENIZER_ID = "sentence-transformers/all-MiniLM-L6-v2"
MAX_LENGTH = 128
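

# NOTE (assumption): the exported graph appears to have been captured with
# fixed (batch=1, seq_len=MAX_LENGTH) inputs, which is why encode() pads
# every sentence to exactly MAX_LENGTH tokens rather than to the longest
# sentence in the batch.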
def encode(model: torch.nn.Module, tokenizer, sentence: str) -> torch.Tensor:
    enc = tokenizer(
        sentence,
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )
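    # For all-MiniLM-L6-v2 the graph should yield a single L2-normalised
    # embedding of shape (1, 384); torch.no_grad() is belt-and-braces,
    # since no gradients are needed at inference anyway.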
    with torch.no_grad():
        embedding = model(enc["input_ids"], enc["attention_mask"])
    return embedding.squeeze(0)


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--sentence",
        required=True,
        action="append",
        help="Sentence to encode. Pass twice to also print cosine similarity.",
    )
    args = parser.parse_args()

    if not PT2_PATH.exists():
        raise SystemExit(
            f"Expected {PT2_PATH.name} next to this script. "
            "Did you `huggingface-cli download` the repo?"
        )
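
    # The tokenizer is not bundled inside the .pt2 artifact; it is fetched
    # from the upstream sentence-transformers repo on the Hugging Face Hub
    # and cached locally after the first run.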
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)

    # The ExportedProgram captured the model in eval mode at export
    # time, so no further .eval() / no_grad toggling is needed (and
    # neither is supported on the .module() wrapper).
    model = torch.export.load(str(PT2_PATH)).module()

    embeddings = [encode(model, tokenizer, s) for s in args.sentence]
    for i, (sentence, emb) in enumerate(zip(args.sentence, embeddings), 1):
        first8 = ", ".join(f"{v:+.4f}" for v in emb[:8].tolist())
        print(f"[{i}] {sentence!r}")
        print(f"    embedding shape: {tuple(emb.shape)}")
        print(f"    first 8 dims:    [{first8}]")
    if len(embeddings) >= 2:
        cos = torch.dot(embeddings[0], embeddings[1]).item()
        print(f"\ncosine similarity (sentences 1 & 2): {cos:+.4f}")


if __name__ == "__main__":
    main()