| import chromadb |
| from .BaseDB import BaseDB |
| import random |
| import string |
| import os |
| from tqdm import tqdm |
|
|
class ChromaDB(BaseDB):
    """Vector store backed by ChromaDB, holding one collection per ``db_name``.

    Collections opened through :meth:`init_from_data` or :meth:`add` are
    cached in ``self.collections``; :meth:`search` only queries collections
    that have been opened on this instance.
    """

    def __init__(self, embedding, save_type="persistent"):
        """Create the underlying Chroma client.

        Args:
            embedding: Embedding function handed to every collection this
                instance creates or opens.
            save_type: ``"persistent"`` stores data on disk in a
                ``chromadb_saves/`` directory next to this module; any other
                value uses an in-memory ``chromadb.Client()``.
        """
        # db_name -> Chroma collection opened during this session.
        self.collections = {}
        self.embedding = embedding

        base_dir = os.path.dirname(os.path.abspath(__file__))
        if save_type == "persistent":
            self.path = os.path.join(base_dir, "./chromadb_saves/")
            self.client = chromadb.PersistentClient(path=self.path)
        else:
            self.client = chromadb.Client()

    def _open_collection(self, db_name):
        """Return collection ``db_name`` (creating it if absent) and cache it."""
        collection = self.collections.get(db_name)
        if collection is None:
            # get_or_create_collection works whether or not the collection
            # already exists in the (possibly persistent) store, so we neither
            # hit "already exists" from create_collection nor depend on the
            # element type returned by list_collections.
            collection = self.client.get_or_create_collection(
                name=db_name,
                embedding_function=self.embedding,
            )
            self.collections[db_name] = collection
        return collection

    def init_from_data(self, data, db_name):
        """Open (or create) ``db_name`` and load ``data`` into it.

        Document ``data[i]`` is stored under id ``str(i)``. Only entries from
        index ``collection.count()`` onward are written, so calling this again
        with a longer list resumes loading after the documents already stored.

        Args:
            data: Sequence of document strings.
            db_name: Name of the Chroma collection.
        """
        collection = self._open_collection(db_name)
        start = collection.count()
        for i in tqdm(range(start, len(data))):
            # upsert (not update): update() skips ids that are not yet in the
            # collection, so it would never insert new documents.
            collection.upsert(
                documents=[data[i]],
                ids=[str(i)],
            )

    def search(self, query, n_results, db_name):
        """Return up to ``n_results`` documents from ``db_name`` closest to ``query``.

        Returns an empty list when ``db_name`` has not been opened on this
        instance, when the collection is empty, or when ``n_results < 1``.
        """
        collection = self.collections.get(db_name)
        if collection is None:
            return []
        # Never request more results than the collection holds.
        n_results = min(collection.count(), n_results)
        if n_results < 1:
            return []
        results = collection.query(query_texts=[query], n_results=n_results)
        # One query text was sent, so take the first (only) result list.
        return results['documents'][0]

    def add(self, text, idx, db_name=""):
        """Store ``text`` under id ``idx`` in ``db_name``.

        The collection is created if it does not exist yet, and an existing
        document with the same id is replaced. ``idx`` is converted with
        ``str()`` so it matches the ``str(i)`` ids written by
        :meth:`init_from_data`.
        """
        collection = self._open_collection(db_name)
        # upsert = "update if the id exists, otherwise add".
        collection.upsert(
            documents=[text],
            ids=[str(idx)],
        )

    def delete(self, idx, db_name):
        """Remove the document with id ``idx`` from ``db_name``.

        Raises:
            KeyError: if ``db_name`` has not been opened on this instance.
        """
        self.collections[db_name].delete(ids=[str(idx)])
|
|
|
|
|
|