| import chromadb |
| from datetime import datetime |
|
|
# Single module-level Chroma client; every helper function in this module
# goes through this one instance.
chroma_client = chromadb.Client()
|
|
|
|
def get_or_create_collection(coll_name: str):
    """Return the collection named *coll_name*, creating it if needed.

    The first six characters of the name are stored under the ``"date"``
    metadata key. ``delete_old_collections`` reads that value back as an
    ``MMDDHH`` hour stamp, so names are presumably meant to start with one.
    """
    return chroma_client.get_or_create_collection(
        name=coll_name,
        metadata={"date": coll_name[:6]},
    )
|
|
|
|
def get_collection(coll_name: str):
    """Look up an existing collection by name.

    Unlike ``get_or_create_collection``, nothing is created here; what
    happens for a missing name is up to ``chroma_client.get_collection``.
    """
    return chroma_client.get_collection(name=coll_name)
|
|
|
|
def reset_collection(coll_name: str):
    """Remove every record from *coll_name* while keeping the collection itself.

    The collection (with its metadata) is not dropped, only emptied.

    Returns:
        The collection object, so callers can keep using it.

    Raises:
        Whatever ``chroma_client.get_collection`` raises when the collection
        does not exist.
    """
    coll = chroma_client.get_collection(name=coll_name)
    # Newer chromadb versions refuse an unfiltered ``delete()`` (no ids,
    # where or where_document), so list the ids and delete them explicitly.
    # ``include=[]`` asks only for the ids, not documents/metadata.
    ids = coll.get(include=[])["ids"]
    if ids:  # an empty ids list would be rejected as an unfiltered delete too
        coll.delete(ids=ids)
    return coll
|
|
|
|
def _stamp_to_datetime(stamp, now):
    """Resolve an MMDDHH hour stamp to the latest datetime not after *now*.

    The stamp carries no year, so the current year is tried first and the
    previous one second (covers Dec -> Jan rollover and Feb 29 stamps).
    Returns None when *stamp* is missing or not a valid MMDDHH value.
    """
    try:
        value = int(stamp)  # same numeric interpretation the stamp always had
    except (TypeError, ValueError):
        return None
    month, day, hour = value // 10000, value // 100 % 100, value % 100
    for year in (now.year, now.year - 1):
        try:
            candidate = datetime(year, month, day, hour)
        except ValueError:  # out-of-range month/day/hour, or Feb 29 in a non-leap year
            continue
        if candidate <= now:
            return candidate
    return None


def delete_old_collections(old=2):
    """Delete collections whose ``"date"`` stamp is more than *old* hours old.

    Each collection's ``metadata["date"]`` (``get_or_create_collection``
    stores the first six characters of the collection name there) is read
    as an ``MMDDHH`` hour stamp. Collections with no metadata, no ``"date"``
    key or an unparseable stamp are left untouched.

    Args:
        old: maximum age in hours; a collection is deleted when its age,
            measured in whole hours, strictly exceeds this value.
    """
    # Stamps have hour resolution, so compare against the current hour.
    now = datetime.now().replace(minute=0, second=0, microsecond=0)
    for coll in chroma_client.list_collections():
        created = _stamp_to_datetime((coll.metadata or {}).get('date'), now)
        if created is None:
            continue
        # Use real elapsed time: subtracting MMDDHH integers breaks across
        # day, month and year boundaries.
        age_hours = (now - created).total_seconds() / 3600
        if age_hours > old:
            chroma_client.delete_collection(coll.name)
|
|
|
|
def add_texts_to_collection(coll_name: str, texts: [str], file: str, source: str):
    """
    Replace the chunks stored for *file* in collection *coll_name* with *texts*.

    All texts originate from the same file. Chunk ``i`` gets the id
    ``"<file>-<i>"`` and the metadata ``{file: 1, 'source': source}``. The
    file name itself is used as a metadata key, and that key is what
    ``query_collection`` filters on.

    Write failures are printed and otherwise ignored (best-effort
    ingestion). A missing collection still raises from ``get_collection``.
    """
    coll = chroma_client.get_collection(name=coll_name)
    if not texts:
        # Nothing to store; recent chromadb versions reject an empty add.
        return
    metadatas = [{file: 1, 'source': source} for _ in texts]
    ids = [f"{file}-{i}" for i in range(len(texts))]
    try:
        # Drop every chunk previously stored for this file, not only the ids
        # about to be reused. If the file now yields fewer chunks, leftover
        # ones would otherwise keep matching ``{file: 1}`` queries.
        coll.delete(where={file: 1})
        coll.add(documents=texts, metadatas=metadatas, ids=ids)
    except Exception as exc:  # not bare: let KeyboardInterrupt/SystemExit through
        print(f"exception raised for collection :{coll_name}, texts: {texts} from file {file} "
              f"and source {source}: {exc!r}")
|
|
|
|
def delete_collection(coll_name: str):
    """Drop the collection called *coll_name* through the shared client.

    Returns None; errors from the client propagate unchanged.
    """
    client = chroma_client
    client.delete_collection(name=coll_name)
|
|
|
|
def list_collections():
    """Return whatever the shared client's ``list_collections()`` yields.

    ``delete_old_collections`` reads ``.metadata`` and ``.name`` from these
    items, i.e. it expects collection objects rather than bare names.
    """
    collections = chroma_client.list_collections()
    return collections
|
|
|
|
def query_collection(coll_name: str, query: str, from_files: [str], n_results: int = 4):
    """Run a similarity *query* on *coll_name*, restricted to chunks of *from_files*.

    Args:
        coll_name: name of an existing collection.
        query: query text.
        from_files: non-empty list of file names. A chunk matches when its
            metadata holds ``{file: 1}`` for any of them, which is how
            ``add_texts_to_collection`` tags chunks.
        n_results: maximum number of results; capped at the collection size.

    Returns:
        The result of ``coll.query`` on success, or ``""`` when there is
        nothing to query (empty collection, ``n_results <= 0``) or the query
        fails. Failures are printed.

    Raises:
        ValueError: if *from_files* is empty.
    """
    if not from_files:
        # Explicit check instead of ``assert``, which ``python -O`` strips.
        raise ValueError("from_files must contain at least one file name")
    coll = chroma_client.get_collection(name=coll_name)
    clauses = [{file: 1} for file in from_files]
    # A single clause is passed as-is; ``$or`` needs at least two operands.
    where_ = clauses[0] if len(clauses) == 1 else {'$or': clauses}
    n_results_ = min(n_results, coll.count())
    if n_results_ <= 0:
        # chromadb rejects a non-positive n_results; nothing to return anyway.
        return ""
    try:
        return coll.query(query_texts=query, n_results=n_results_, where=where_)
    except Exception as exc:  # best-effort: report and fall back to ""
        print(f"exception raised at query collection for collection {coll_name} and query {query} from files "
              f"{from_files}: {exc!r}")
        return ""
|
|