"""Run inference on the Embedl all-MiniLM-L6-v2 INT8 sentence encoder via torch.export.

Loads the shipped ``embedl_all-MiniLM-L6-v2_int8.pt2`` artifact with
``torch.export.load`` and encodes a sentence (or pair of sentences)
into an L2-normalised embedding. No TensorRT or ONNX runtime is
required — just PyTorch + transformers (for the tokenizer).

Usage::

    python infer_pt2.py --sentence "A man is eating food."
    python infer_pt2.py --sentence "A man is eating." \\
        --sentence "A man is having a meal."
"""
|
|
import argparse
from pathlib import Path

import torch
from transformers import AutoTokenizer

# The exported program artifact is expected to sit next to this script.
PT2_PATH = Path(__file__).with_name("embedl_all-MiniLM-L6-v2_int8.pt2")
# Tokenizer of the base (non-quantized) model; the INT8 export shares its vocab.
TOKENIZER_ID = "sentence-transformers/all-MiniLM-L6-v2"
# Fixed sequence length — presumably matches the static shape the .pt2 was
# exported with; TODO confirm against the export script.
MAX_LENGTH = 128
|
|
|
|
def encode(
    model: torch.nn.Module,
    tokenizer,
    sentence: str,
    *,
    max_length=None,
) -> torch.Tensor:
    """Encode *sentence* into a 1-D embedding tensor.

    Args:
        model: Callable taking ``(input_ids, attention_mask)`` and returning a
            batched embedding tensor (batch dimension first).
        tokenizer: Hugging Face tokenizer (or compatible callable) producing
            ``input_ids`` / ``attention_mask`` PyTorch tensors.
        sentence: Text to encode.
        max_length: Sequence length to pad/truncate to. Defaults to the
            module-level ``MAX_LENGTH`` (the export's fixed sequence length).

    Returns:
        The embedding with the batch dimension squeezed off.
    """
    if max_length is None:
        max_length = MAX_LENGTH
    enc = tokenizer(
        sentence,
        # Pad to a fixed length: the exported graph expects static shapes.
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        embedding = model(enc["input_ids"], enc["attention_mask"])
    return embedding.squeeze(0)  # (1, dim) -> (dim,)
|
|
|
|
def main() -> None:
    """CLI entry point: load the exported model and encode the given sentences.

    Prints each embedding's shape and first dimensions; when two or more
    sentences are given, also prints the cosine similarity of the first two.

    Raises:
        SystemExit: If the ``.pt2`` artifact is not found next to the script.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--sentence",
        required=True,
        action="append",  # repeatable flag: collects all sentences into a list
        help="Sentence to encode. Pass twice to also print cosine similarity.",
    )
    args = parser.parse_args()

    if not PT2_PATH.exists():
        raise SystemExit(
            f"Expected {PT2_PATH.name} next to this script. "
            "Did you `huggingface-cli download` the repo?"
        )

    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
    # torch.export.load returns an ExportedProgram; .module() yields a callable.
    model = torch.export.load(str(PT2_PATH)).module()

    embeddings = [encode(model, tokenizer, s) for s in args.sentence]

    for i, (sentence, emb) in enumerate(zip(args.sentence, embeddings), 1):
        first8 = ", ".join(f"{v:+.4f}" for v in emb[:8].tolist())
        print(f"[{i}] {sentence!r}")
        print(f"    embedding shape: {tuple(emb.shape)}")
        print(f"    first 8 dims: [{first8}]")

    if len(embeddings) >= 2:
        # Use true cosine similarity rather than a bare dot product so the
        # printed label is accurate even if the model output is not exactly
        # unit-norm; identical to the dot product when it is L2-normalised.
        cos = torch.nn.functional.cosine_similarity(
            embeddings[0], embeddings[1], dim=0
        ).item()
        print(f"\ncosine similarity (sentences 1 & 2): {cos:+.4f}")
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|