Spaces:
Sleeping
Sleeping
Commit ·
5d0d255
1
Parent(s): 11ce507
MOD: batch size
Browse files
- retriever/embedder.py +3 -3
retriever/embedder.py
CHANGED
|
@@ -12,7 +12,7 @@ class Embedder:
|
|
| 12 |
self.model = SentenceTransformer(model_name, device=device)
|
| 13 |
logger.info(f"Loaded embedding model: {model_name}")
|
| 14 |
|
| 15 |
-
def encode(self, texts: Union[str, List[str]], batch_size: int =
|
| 16 |
"""Encode texts to embeddings"""
|
| 17 |
if isinstance(texts, str):
|
| 18 |
texts = [texts]
|
|
@@ -26,7 +26,7 @@ class Embedder:
|
|
| 26 |
|
| 27 |
return embeddings
|
| 28 |
|
| 29 |
-
def encode_queries(self, queries: List[str], batch_size: int =
|
| 30 |
"""Encode queries with query prefix"""
|
| 31 |
if not queries:
|
| 32 |
return np.array([])
|
|
@@ -35,7 +35,7 @@ class Embedder:
|
|
| 35 |
prefixed_queries = [f"Represent this sentence for searching relevant passages: {q}" for q in queries]
|
| 36 |
return self.encode(prefixed_queries, batch_size)
|
| 37 |
|
| 38 |
-
def encode_passages(self, passages: List[str], batch_size: int =
|
| 39 |
"""Encode passages with passage prefix"""
|
| 40 |
if not passages:
|
| 41 |
return np.array([])
|
|
|
|
| 12 |
self.model = SentenceTransformer(model_name, device=device)
|
| 13 |
logger.info(f"Loaded embedding model: {model_name}")
|
| 14 |
|
| 15 |
+
def encode(self, texts: Union[str, List[str]], batch_size: int = 16) -> np.ndarray:
|
| 16 |
"""Encode texts to embeddings"""
|
| 17 |
if isinstance(texts, str):
|
| 18 |
texts = [texts]
|
|
|
|
| 26 |
|
| 27 |
return embeddings
|
| 28 |
|
| 29 |
+
def encode_queries(self, queries: List[str], batch_size: int = 16) -> np.ndarray:
|
| 30 |
"""Encode queries with query prefix"""
|
| 31 |
if not queries:
|
| 32 |
return np.array([])
|
|
|
|
| 35 |
prefixed_queries = [f"Represent this sentence for searching relevant passages: {q}" for q in queries]
|
| 36 |
return self.encode(prefixed_queries, batch_size)
|
| 37 |
|
| 38 |
+
def encode_passages(self, passages: List[str], batch_size: int = 16) -> np.ndarray:
|
| 39 |
"""Encode passages with passage prefix"""
|
| 40 |
if not passages:
|
| 41 |
return np.array([])
|