| from application.vectorstore.base import BaseVectorStore |
| from application.core.settings import settings |
| from application.vectorstore.document_class import Document |
| import elasticsearch |
|
|
|
|
|
|
|
|
class ElasticsearchStore(BaseVectorStore):
    """Vector store backed by an Elasticsearch index.

    Documents are indexed with a ``text`` field, a dense ``vector`` field and a
    ``metadata`` object; all reads and deletes are scoped to this store's
    source via a ``metadata.store.keyword`` filter matching ``self.path``.
    """

    # Class-level client shared by all instances so only one connection pool
    # is opened per process.
    _es_connection = None

    def __init__(self, path, embeddings_key, index_name=settings.ELASTIC_INDEX):
        """Create a store scoped to ``path`` within ``index_name``.

        Args:
            path: Source identifier. The ``"application/indexes/"`` prefix and
                any trailing slash are stripped; the result is used as the
                ``metadata.store.keyword`` filter value.
            embeddings_key: Credential forwarded to the embeddings backend.
            index_name: Elasticsearch index to read from and write to.

        Raises:
            ValueError: if neither ``ELASTIC_URL`` nor ``ELASTIC_CLOUD_ID``
                is configured.
        """
        super().__init__()
        self.path = path.replace("application/indexes/", "").rstrip("/")
        self.embeddings_key = embeddings_key
        self.index_name = index_name

        if ElasticsearchStore._es_connection is None:
            connection_params = {}
            if settings.ELASTIC_URL:
                connection_params["hosts"] = [settings.ELASTIC_URL]
                # Use ``basic_auth`` (the v8 client parameter) rather than the
                # deprecated ``http_auth``, consistent with the cloud branch.
                connection_params["basic_auth"] = (settings.ELASTIC_USERNAME, settings.ELASTIC_PASSWORD)
            elif settings.ELASTIC_CLOUD_ID:
                connection_params["cloud_id"] = settings.ELASTIC_CLOUD_ID
                connection_params["basic_auth"] = (settings.ELASTIC_USERNAME, settings.ELASTIC_PASSWORD)
            else:
                raise ValueError("Please provide either elasticsearch_url or cloud_id.")

            ElasticsearchStore._es_connection = elasticsearch.Elasticsearch(**connection_params)

        self.docsearch = ElasticsearchStore._es_connection

    @staticmethod
    def connect_to_elasticsearch(
        *,
        es_url=None,
        cloud_id=None,
        api_key=None,
        username=None,
        password=None,
    ):
        """Build and verify a standalone Elasticsearch client.

        Exactly one of ``es_url`` / ``cloud_id`` must be supplied.
        Authentication is via ``api_key``, or ``username``/``password`` when
        both are given.

        Returns:
            A connected ``elasticsearch.Elasticsearch`` client.

        Raises:
            ImportError: if the elasticsearch package is not installed.
            ValueError: if both or neither of ``es_url``/``cloud_id`` are set.

        Note:
            This is a ``@staticmethod``: every parameter is keyword-only, so
            there is no slot for ``self`` and it cannot be a bound method.
        """
        try:
            import elasticsearch
        except ImportError:
            raise ImportError(
                "Could not import elasticsearch python package. "
                "Please install it with `pip install elasticsearch`."
            )

        if es_url and cloud_id:
            raise ValueError(
                "Both es_url and cloud_id are defined. Please provide only one."
            )

        connection_params = {}

        if es_url:
            connection_params["hosts"] = [es_url]
        elif cloud_id:
            connection_params["cloud_id"] = cloud_id
        else:
            raise ValueError("Please provide either elasticsearch_url or cloud_id.")

        if api_key:
            connection_params["api_key"] = api_key
        elif username and password:
            connection_params["basic_auth"] = (username, password)

        es_client = elasticsearch.Elasticsearch(**connection_params)
        # Fail fast if the cluster is unreachable or credentials are wrong.
        es_client.info()

        return es_client

    def search(self, question, k=2, index_name=None, *args, **kwargs):
        """Hybrid search (kNN over ``vector`` + BM25 over ``text``).

        Args:
            question: Natural-language query; embedded for the kNN leg and
                used verbatim for the full-text leg.
            k: Number of hits to return.
            index_name: Optional index override; defaults to this store's
                ``self.index_name`` (which is also what the original default
                resolved to, so default-call behavior is unchanged).

        Returns:
            list[Document]: hits wrapped as ``Document(page_content, metadata)``.
        """
        embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, self.embeddings_key)
        vector = embeddings.embed_query(question)
        # Both legs carry the same store filter so results stay scoped to
        # this source.
        store_filter = [{"match": {"metadata.store.keyword": self.path}}]
        knn = {
            "filter": store_filter,
            "field": "vector",
            "k": k,
            "num_candidates": 100,
            "query_vector": vector,
        }
        query = {
            "bool": {
                "must": [
                    {
                        "match": {
                            "text": {
                                "query": question,
                            }
                        }
                    }
                ],
                "filter": store_filter,
            }
        }
        resp = self.docsearch.search(
            index=index_name or self.index_name,
            query=query,
            size=k,
            knn=knn,
        )

        return [
            Document(page_content=hit["_source"]["text"], metadata=hit["_source"]["metadata"])
            for hit in resp["hits"]["hits"]
        ]

    def _create_index_if_not_exists(self, index_name, dims_length):
        """Create ``index_name`` with a dense-vector mapping unless it exists.

        Args:
            index_name: Target index name.
            dims_length: Dimensionality of the embedding vectors.
        """
        if self._es_connection.indices.exists(index=index_name):
            print(f"Index {index_name} already exists.")
        else:
            index_settings = self.index(dims_length=dims_length)
            self._es_connection.indices.create(index=index_name, **index_settings)

    def index(self, dims_length):
        """Return the index body mapping ``vector`` as a cosine dense_vector.

        Args:
            dims_length: Dimensionality of the embedding vectors.

        Returns:
            dict suitable for ``indices.create(**...)``.
        """
        return {
            "mappings": {
                "properties": {
                    "vector": {
                        "type": "dense_vector",
                        "dims": dims_length,
                        "index": True,
                        "similarity": "cosine",
                    },
                }
            }
        }

    def add_texts(
        self,
        texts,
        metadatas=None,
        ids=None,
        refresh_indices=True,
        create_index_if_not_exists=True,
        bulk_kwargs=None,
        **kwargs,
    ):
        """Embed ``texts`` and bulk-index them into ``self.index_name``.

        Args:
            texts: Iterable of strings to embed and index (materialized once,
                so generators are safe).
            metadatas: Optional per-text metadata dicts, aligned with ``texts``.
            ids: Optional document ids; random UUID4 strings are generated
                when omitted.
            refresh_indices: Passed to ``bulk`` as ``refresh``.
            create_index_if_not_exists: Create the index (with a dense-vector
                mapping sized from the first embedding) when missing.
            bulk_kwargs: Extra keyword arguments forwarded to ``bulk``.

        Returns:
            list[str]: the ids of the indexed documents ([] for empty input).

        Raises:
            elasticsearch.helpers.BulkIndexError: re-raised after logging the
                first failure reason.
        """
        from elasticsearch.helpers import BulkIndexError, bulk
        import uuid

        bulk_kwargs = bulk_kwargs or {}
        # Materialize once: the original iterated ``texts`` twice, which
        # silently dropped all documents for generator inputs.
        texts = list(texts)
        if not texts:
            return []

        ids = ids or [str(uuid.uuid4()) for _ in texts]
        embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, self.embeddings_key)
        vectors = embeddings.embed_documents(texts)
        dims_length = len(vectors[0])

        if create_index_if_not_exists:
            self._create_index_if_not_exists(
                index_name=self.index_name, dims_length=dims_length
            )

        requests = [
            {
                "_op_type": "index",
                "_index": self.index_name,
                "_id": ids[i],
                "text": text,
                "vector": vector,
                "metadata": metadatas[i] if metadatas else {},
            }
            for i, (text, vector) in enumerate(zip(texts, vectors))
        ]

        try:
            bulk(
                self._es_connection,
                requests,
                stats_only=True,
                refresh=refresh_indices,
                **bulk_kwargs,
            )
        except BulkIndexError as e:
            print(f"Error adding texts: {e}")
            first_error = e.errors[0].get("index", {}).get("error", {})
            print(f"First error reason: {first_error.get('reason')}")
            raise
        return ids

    def delete_index(self):
        """Delete every document belonging to this store's ``path``.

        Note this removes only this source's documents (filtered by
        ``metadata.store.keyword``), not the whole index.
        """
        self._es_connection.delete_by_query(
            index=self.index_name,
            query={"match": {"metadata.store.keyword": self.path}},
        )