| import chainlit as cl |
| from langchain.retrievers import ParentDocumentRetriever |
| from langchain.schema.runnable import RunnableConfig |
| from langchain.storage import LocalFileStore |
| from langchain.storage._lc_store import create_kv_docstore |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| from langchain.vectorstores.chroma import Chroma |
| from langchain_google_genai import ( |
| GoogleGenerativeAI, |
| GoogleGenerativeAIEmbeddings, |
| HarmBlockThreshold, |
| HarmCategory, |
| ) |
|
|
| import config |
| from prompts import prompt |
| from utils import PostMessageHandler, format_docs |
|
|
# LLM used to answer questions. Dangerous-content filtering is disabled
# deliberately so the model does not refuse benign technical questions.
model = GoogleGenerativeAI(
    model=config.GOOGLE_CHAT_MODEL,
    google_api_key=config.GOOGLE_API_KEY,
    safety_settings={
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    },
)


# Embedding model backing the vector store. Pass the API key explicitly,
# matching the chat model above, instead of relying on the GOOGLE_API_KEY
# environment variable being set at runtime.
embeddings_model = GoogleGenerativeAIEmbeddings(
    model=config.GOOGLE_EMBEDDING_MODEL,
    google_api_key=config.GOOGLE_API_KEY,
)


# Splitter producing the small child chunks that get embedded and indexed;
# retrieval later maps hits back to their full parent documents.
child_splitter = RecursiveCharacterTextSplitter(chunk_size=500, separators=["\n"])


# Persistent Chroma index of the child chunks.
# NOTE(review): STORAGE_PATH is assumed to end with a path separator —
# confirm; otherwise "vectorstore"/"docstore" are glued onto the last
# path component.
vectorstore = Chroma(
    persist_directory=config.STORAGE_PATH + "vectorstore",
    collection_name="full_documents",
    embedding_function=embeddings_model,
)


# On-disk key/value store holding the full parent documents.
fs = LocalFileStore(config.STORAGE_PATH + "docstore")
store = create_kv_docstore(fs)


# Retriever that searches the child-chunk embeddings but returns the
# parent documents the matching chunks came from.
retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=child_splitter,
)
|
|
|
|
@cl.on_chat_start
async def on_chat_start():
    """Stash the shared retriever in the per-user session when a chat opens."""
    cl.user_session.set("retriever", retriever)
|
|
|
|
@cl.on_message
async def on_message(message: cl.Message):
    """Answer a user message with retrieval-augmented generation.

    Fetches documents relevant to the question, feeds them as context to
    the prompt | model chain, and streams the answer token-by-token into
    a Chainlit message before finalizing it.
    """
    chain = prompt | model
    msg = cl.Message(content="")

    async with cl.Step(type="run", name="QA Assistant"):
        question = message.content
        # Use the retriever stored in the session at chat start (it was
        # previously set but never read); fall back to the module-level
        # instance so the handler still works standalone.
        session_retriever = cl.user_session.get("retriever") or retriever
        # Async retrieval: the synchronous get_relevant_documents() call
        # would block the event loop for the duration of the search.
        docs = await session_retriever.aget_relevant_documents(question)
        context = format_docs(docs)
        async for chunk in chain.astream(
            input={"context": context, "question": question},
            config=RunnableConfig(
                callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
            ),
        ):
            await msg.stream_token(chunk)

    await msg.send()
|
|