| from langchain_core.runnables import RunnableLambda |
| from langchain_huggingface.embeddings import HuggingFaceEmbeddings |
|
|
| from src.db_utils.qdrant_utils import qdrant_search |
|
|
|
|
class Retriever:
    """Embed a query, look it up with ``qdrant_search``, and format the hits.

    The retrieval step is also exposed as ``self.chain``, a
    ``RunnableLambda`` wrapping :meth:`_retrieve`.
    """

    def __init__(self, embed_model_name: str, embed_index_name: str):
        """Set up the embedding model and the retrieval runnable.

        Args:
            embed_model_name: Model name handed to ``HuggingFaceEmbeddings``.
            embed_index_name: Value passed as the first argument to
                ``qdrant_search`` on every query (presumably the name of
                the collection/index to search -- confirm against
                ``qdrant_search``).
        """
        self.embed_index_name = embed_index_name
        # Embeddings are requested in normalized form via encode_kwargs.
        normalize_opts = {"normalize_embeddings": True}
        self.embed_model = HuggingFaceEmbeddings(
            model_name=embed_model_name,
            encode_kwargs=normalize_opts,
        )
        self.chain = RunnableLambda(self._retrieve)

    def _retrieve(self, query: str) -> str:
        """Return the search hits for *query* as a numbered text list.

        The query is embedded with ``self.embed_model`` and the resulting
        vector is passed to ``qdrant_search`` together with
        ``self.embed_index_name``.

        Args:
            query: The text to embed and search for.

        Returns:
            One line per entry of the result's ``points``, in the order
            returned, formatted as ``"<n>) <payload['text']>"`` with ``n``
            starting at 1 and lines joined by newlines. An empty string
            when ``points`` is empty.
        """
        query_vector = self.embed_model.embed_query(query)
        response = qdrant_search(self.embed_index_name, query_vector)
        numbered_lines = [
            f"{rank}) {point.payload['text']}"
            for rank, point in enumerate(response.points, start=1)
        ]
        return "\n".join(numbered_lines)
|
|