# Copyright (C) 2026 Embedl AB
"""Run inference on the Embedl All Minilm L6 V2 INT8 sentence encoder via torch.export.

Loads the shipped ``embedl_all-MiniLM-L6-v2_int8.pt2`` artifact with
``torch.export.load`` and encodes a sentence (or pair of sentences) into
an L2-normalised embedding. No TensorRT or ONNX runtime is required —
just PyTorch + transformers (for the tokenizer).

Usage::

    python infer_pt2.py --sentence "A man is eating food."
    python infer_pt2.py --sentence "A man is eating." \\
        --sentence "A man is having a meal."
"""

import argparse
from pathlib import Path

import torch
from transformers import AutoTokenizer

# The exported artifact is expected to sit in the same directory as this script.
PT2_PATH = Path(__file__).with_name("embedl_all-MiniLM-L6-v2_int8.pt2")
TOKENIZER_ID = "sentence-transformers/all-MiniLM-L6-v2"
# Every input is padded / truncated to exactly this many tokens.
MAX_LENGTH = 128


def encode(model: torch.nn.Module, tokenizer, sentence: str) -> torch.Tensor:
    """Encode a single sentence into an embedding tensor.

    Args:
        model: Callable module taking ``(input_ids, attention_mask)``.
        tokenizer: Tokenizer callable returning a mapping that contains
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        sentence: Text to encode.

    Returns:
        The model output with a leading size-1 dimension (if any) removed.
        NOTE(review): presumably a 1-D embedding vector, i.e. the model
        returns ``(1, dim)`` for a single sentence — confirm against the
        exported graph.
    """
    # Always pad to MAX_LENGTH so the model sees a fixed sequence length.
    # NOTE(review): presumably required because the program was exported
    # with static input shapes — confirm.
    enc = tokenizer(
        sentence,
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        embedding = model(enc["input_ids"], enc["attention_mask"])
    return embedding.squeeze(0)


def main() -> None:
    """Parse CLI arguments, encode each sentence and print the results.

    Prints the shape and first eight values of every embedding and, when
    at least two sentences are given, the cosine similarity between the
    first two.

    Raises:
        SystemExit: If the ``.pt2`` artifact is not found next to this
            script (or on invalid command-line arguments, via argparse).
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the docstring's line breaks so the multi-line ``Usage::``
        # example stays readable in ``--help`` instead of being re-wrapped.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--sentence",
        required=True,
        action="append",
        help="Sentence to encode. Pass twice to also print cosine similarity.",
    )
    args = parser.parse_args()

    if not PT2_PATH.exists():
        raise SystemExit(
            f"Expected {PT2_PATH.name} next to this script. "
            "Did you `huggingface-cli download` the repo?"
        )

    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
    # The ExportedProgram captured the model in eval mode at export time,
    # so calling .eval() is not needed (and is not supported on the
    # .module() wrapper). encode() still runs the forward pass under
    # torch.no_grad() to avoid autograd tracking.
    model = torch.export.load(str(PT2_PATH)).module()

    embeddings = [encode(model, tokenizer, s) for s in args.sentence]

    for i, (sentence, emb) in enumerate(zip(args.sentence, embeddings), 1):
        first8 = ", ".join(f"{v:+.4f}" for v in emb[:8].tolist())
        print(f"[{i}] {sentence!r}")
        print(f" embedding shape: {tuple(emb.shape)}")
        print(f" first 8 dims: [{first8}]")

    if len(embeddings) >= 2:
        # Use a true cosine similarity rather than a bare dot product: the
        # two coincide for unit-norm vectors, but this stays correct if the
        # quantised model's outputs are not exactly L2-normalised.
        cos = torch.nn.functional.cosine_similarity(
            embeddings[0], embeddings[1], dim=0
        ).item()
        print(f"\ncosine similarity (sentences 1 & 2): {cos:+.4f}")


if __name__ == "__main__":
    main()