| import google.generativeai as genai
|
| from chromadb import Documents, EmbeddingFunction, Embeddings, PersistentClient, Collection
|
| from typing import Dict, List
|
| import os
|
| from dotenv import load_dotenv
|
| load_dotenv(override=True)
|
| from text_chunk import *
|
|
|
class GeminiEmbeddingFuction(EmbeddingFunction):
    """Chroma embedding function backed by the Gemini embedding API.

    Chroma invokes an instance of this class with a batch of documents and
    stores the vectors it returns. Each invocation configures the
    ``google.generativeai`` client with the key read from the ``GEMINI_API``
    environment variable, then embeds the whole batch with the
    ``models/embedding-001`` model.

    The class name keeps its original spelling ("Fuction") because
    ``create_chroma_db`` and ``load_chroma_db`` refer to it by that name.
    """

    def __call__(self, input: Documents) -> Embeddings:
        """Embed ``input`` and return the resulting vectors.

        Parameters:
        - input (Documents): The documents Chroma wants embedded.

        Returns:
        - Embeddings: The ``'embedding'`` entry of the Gemini response.
        """
        # The key is looked up on every call, so a change to the environment
        # after import is picked up.
        api_key = os.getenv("GEMINI_API")
        genai.configure(api_key=api_key)

        # NOTE(review): this same function also embeds query texts when
        # Chroma runs db.query(query_texts=...); Gemini additionally offers a
        # "retrieval_query" task type — confirm "retrieval_document" is
        # intended for both.
        response = genai.embed_content(
            model="models/embedding-001",
            content=input,
            task_type="retrieval_document",
            title="Query",
        )
        return response['embedding']
|
|
|
|
|
| def create_chroma_db(documents: List[str], path: str, name: str):
|
| """
|
| Creates a Chroma database using the provided documents, path, and collection name.
|
|
|
| Parameters:
|
| - documents: An iterable of documents to be added to the Chroma database.
|
| - path (str): The path where the Chroma database will be stored.
|
| - name (str): The name of the collection within the Chroma database.
|
|
|
| Returns:
|
| - Tuple[chromadb.Collection, str]: A tuple containing the created Chroma Collection and its name.
|
| """
|
|
|
| chroma_client = PersistentClient(path=path)
|
| db = chroma_client.create_collection(name=name,
|
| embedding_function=GeminiEmbeddingFuction())
|
| for i, d in enumerate(documents):
|
| db.add(documents=[d], ids = str(i))
|
| return db, name
|
|
|
def load_chroma_db(path: str, name: str):
    """
    Open an existing Chroma collection stored on disk.

    A persistent Chroma client is created at ``path`` and the collection
    called ``name`` is fetched with the Gemini embedding function attached,
    so later queries embed their text the same way.

    Parameters:
    - path (str): Directory holding the persisted Chroma database.
    - name (str): Name of the collection to open.

    Returns:
    - chromadb.Collection: The requested collection.
    """
    return PersistentClient(path=path).get_collection(
        name=name,
        embedding_function=GeminiEmbeddingFuction(),
    )
|
|
|
def get_relevant_passage(query: str, db: Collection, n_results: int):
    """
    Run a semantic search and return the best-matching text chunks.

    Parameters:
    - query (str): Text to search for.
    - db (chromadb.Collection): Collection to search.
    - n_results (int): How many chunks to request.

    Returns:
    - List[str]: The matching chunks for ``query``.
    """
    results = db.query(query_texts=[query], n_results=n_results)
    # Exactly one query text was sent, so keep only that query's hit list.
    documents_per_query = results['documents']
    return documents_per_query[0]
|
|
|
def _main() -> None:
    """Script entry point; currently it only reports completion."""
    print("Done")


if __name__ == "__main__":
    _main()
|
|
|
|
|
|
|
|
|
| |