| """ |
| Tests for computing sentence embeddings with SentenceTransformer.encode |
| """ |
|
|
|
|
| import unittest |
| from sentence_transformers import SentenceTransformer |
| import numpy as np |
|
|
class ComputeEmbeddingsTest(unittest.TestCase):
    """Regression checks for ``SentenceTransformer.encode``.

    Each test encodes fixed inputs with the
    ``paraphrase-distilroberta-base-v1`` model and compares the output shape
    and the sum (or norm) of the embeddings against hard-coded reference
    values. The reference sums belong to these exact input strings, so the
    strings must not be edited (including the "Hello Word" spelling).
    """

    # Size of the embedding vectors the tests expect from the model.
    EMBEDDING_DIM = 768
    # Maximum absolute deviation tolerated for embedding sums and norms.
    TOLERANCE = 0.001

    @classmethod
    def setUpClass(cls):
        """Load the pretrained model once and share it between all tests.

        No test modifies the model, so a per-test reload is unnecessary.
        """
        cls.model = SentenceTransformer('paraphrase-distilroberta-base-v1')

    def _assert_sum_close(self, emb, expected):
        """Assert that the sum over all values of *emb* is within TOLERANCE of *expected*."""
        self.assertAlmostEqual(float(np.sum(emb)), expected, delta=self.TOLERANCE)

    def test_encode_token_embeddings(self):
        """Test that encode(output_value='token_embeddings') works.

        One entry must be returned per sentence, and the first dimension of
        each entry must equal the number of input ids the model's tokenizer
        produces for that sentence.
        """
        sent = [
            "Hello Word, a test sentence",
            "Here comes another sentence",
            "My final sentence",
            "Sentences",
            "Sentence five five five five five five five",
        ]
        # batch_size=2 with five sentences forces several batches, the last
        # of which is only partially filled.
        emb = self.model.encode(sent, output_value='token_embeddings', batch_size=2)
        self.assertEqual(len(emb), len(sent))
        for s, e in zip(sent, emb):
            # subTest reports which sentence failed and keeps checking the rest.
            with self.subTest(sentence=s):
                num_tokens = len(self.model.tokenize([s])['input_ids'][0])
                self.assertEqual(num_tokens, e.shape[0])

    def test_encode_single_sentences(self):
        """Test encoding of a bare string, a one-element list and a three-element list."""
        # A bare string yields a single 1-D vector.
        emb = self.model.encode("Hello Word, a test sentence")
        self.assertEqual(emb.shape, (self.EMBEDDING_DIM,))
        self._assert_sum_close(emb, 7.9811716)

        # The same sentence wrapped in a list yields a 2-D array with one row
        # and the same sum.
        emb = self.model.encode(["Hello Word, a test sentence"])
        self.assertEqual(emb.shape, (1, self.EMBEDDING_DIM))
        self._assert_sum_close(emb, 7.9811716)

        # Several sentences yield one row each.
        emb = self.model.encode(["Hello Word, a test sentence", "Here comes another sentence", "My final sentence"])
        self.assertEqual(emb.shape, (3, self.EMBEDDING_DIM))
        self._assert_sum_close(emb, 22.968266)

    def test_encode_normalize(self):
        """Test that normalize_embeddings=True returns rows of unit L2 norm."""
        emb = self.model.encode(["Hello Word, a test sentence", "Here comes another sentence", "My final sentence"], normalize_embeddings=True)
        self.assertEqual(emb.shape, (3, self.EMBEDDING_DIM))
        for norm in np.linalg.norm(emb, axis=1):
            self.assertAlmostEqual(float(norm), 1.0, delta=self.TOLERANCE)

    def test_encode_tuple_sentences(self):
        """Test encoding of inputs given as (first, second) string tuples."""
        # A single tuple yields one row.
        emb = self.model.encode([("Hello Word, a test sentence", "Second input for model")])
        self.assertEqual(emb.shape, (1, self.EMBEDDING_DIM))
        self._assert_sum_close(emb, 9.503508)

        # Several tuples yield one row each.
        emb = self.model.encode([("Hello Word, a test sentence", "Second input for model"), ("My second tuple", "With two inputs"), ("Final tuple", "final test")])
        self.assertEqual(emb.shape, (3, self.EMBEDDING_DIM))
        self._assert_sum_close(emb, 32.14627)
|
|
|
|
|
|
|
|
|
|
|
|