Spaces:
Sleeping
Sleeping
httpdaniel committed on
Commit ·
3a8a9b9
1
Parent(s): c0ecbbb
Adding system prompt
Browse files
app.py
CHANGED
|
@@ -4,9 +4,9 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
| 4 |
from langchain_chroma import Chroma
|
| 5 |
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
|
| 6 |
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
|
| 7 |
-
from
|
| 8 |
-
from
|
| 9 |
-
from
|
| 10 |
|
| 11 |
def initialise_vectorstore(pdf, progress=gr.Progress()):
|
| 12 |
progress(0, desc="Reading PDF")
|
|
@@ -35,8 +35,8 @@ def initialise_chain(llm, vectorstore, progress=gr.Progress()):
|
|
| 35 |
repo_id=llm,
|
| 36 |
task="text-generation",
|
| 37 |
max_new_tokens=512,
|
| 38 |
-
|
| 39 |
-
|
| 40 |
)
|
| 41 |
|
| 42 |
chat = ChatHuggingFace(
|
|
@@ -47,23 +47,36 @@ def initialise_chain(llm, vectorstore, progress=gr.Progress()):
|
|
| 47 |
progress(0.5, desc="Initialising RAG Chain")
|
| 48 |
|
| 49 |
retriever = vectorstore.as_retriever()
|
| 50 |
-
prompt = hub.pull("rlm/rag-prompt")
|
| 51 |
-
parser = StrOutputParser()
|
| 52 |
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
return rag_chain, progress
|
| 58 |
|
| 59 |
def send(message, rag_chain, chat_history):
|
| 60 |
-
response = rag_chain.invoke(message)
|
| 61 |
-
chat_history.append((message, response))
|
| 62 |
-
return "", chat_history
|
| 63 |
-
|
| 64 |
-
def restart():
|
| 65 |
-
return f"Restarting"
|
| 66 |
|
|
|
|
| 67 |
|
| 68 |
with gr.Blocks() as demo:
|
| 69 |
|
|
@@ -74,7 +87,6 @@ with gr.Blocks() as demo:
|
|
| 74 |
gr.Markdown("<H3>Upload and ask questions about your PDF files</H3>")
|
| 75 |
gr.Markdown("<H6>Note: This project uses LangChain to perform RAG (Retrieval Augmented Generation) on PDF files, allowing users to ask any questions related to their contents. When a PDF file is uploaded, it is embedded and stored in an in-memory Chroma vectorstore, which the chatbot uses as a source of knowledge when aswering user questions.</H6>")
|
| 76 |
|
| 77 |
-
# Vectorstore Tab
|
| 78 |
with gr.Tab("Vectorstore"):
|
| 79 |
with gr.Row():
|
| 80 |
input_pdf = gr.File()
|
|
@@ -91,10 +103,9 @@ with gr.Blocks() as demo:
|
|
| 91 |
with gr.Row():
|
| 92 |
vectorstore_initialisation_progress = gr.Textbox(value="None", label="Initialization")
|
| 93 |
|
| 94 |
-
# RAG Chain
|
| 95 |
with gr.Tab("RAG Chain"):
|
| 96 |
with gr.Row():
|
| 97 |
-
language_model = gr.Radio(["microsoft/Phi-3-mini-4k-instruct", "mistralai/Mistral-7B-Instruct-v0.2", "
|
| 98 |
with gr.Row():
|
| 99 |
with gr.Column(scale=1, min_width=0):
|
| 100 |
pass
|
|
@@ -108,35 +119,14 @@ with gr.Blocks() as demo:
|
|
| 108 |
with gr.Row():
|
| 109 |
chain_initialisation_progress = gr.Textbox(value="None", label="Initialization")
|
| 110 |
|
| 111 |
-
# Chatbot Tab
|
| 112 |
with gr.Tab("Chatbot"):
|
| 113 |
with gr.Row():
|
| 114 |
chatbot = gr.Chatbot()
|
| 115 |
-
with gr.Accordion("Advanced - Document references", open=False):
|
| 116 |
-
with gr.Row():
|
| 117 |
-
doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
|
| 118 |
-
source1_page = gr.Number(label="Page", scale=1)
|
| 119 |
-
with gr.Row():
|
| 120 |
-
doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
|
| 121 |
-
source2_page = gr.Number(label="Page", scale=1)
|
| 122 |
-
with gr.Row():
|
| 123 |
-
doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
|
| 124 |
-
source3_page = gr.Number(label="Page", scale=1)
|
| 125 |
with gr.Row():
|
| 126 |
message = gr.Textbox()
|
| 127 |
-
with gr.Row():
|
| 128 |
-
send_btn = gr.Button(
|
| 129 |
-
"Send",
|
| 130 |
-
variant=["primary"]
|
| 131 |
-
)
|
| 132 |
-
restart_btn = gr.Button(
|
| 133 |
-
"Restart",
|
| 134 |
-
variant=["secondary"]
|
| 135 |
-
)
|
| 136 |
|
| 137 |
initialise_vectorstore_btn.click(fn=initialise_vectorstore, inputs=input_pdf, outputs=[vectorstore, vectorstore_initialisation_progress])
|
| 138 |
initialise_chain_btn.click(fn=initialise_chain, inputs=[language_model, vectorstore], outputs=[rag_chain, chain_initialisation_progress])
|
| 139 |
-
|
| 140 |
-
restart_btn.click(fn=restart)
|
| 141 |
|
| 142 |
demo.launch()
|
|
|
|
| 4 |
from langchain_chroma import Chroma
|
| 5 |
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
|
| 6 |
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
|
| 7 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 8 |
+
from langchain.chains.combine_documents import create_stuff_documents_chain
|
| 9 |
+
from langchain.chains import create_retrieval_chain
|
| 10 |
|
| 11 |
def initialise_vectorstore(pdf, progress=gr.Progress()):
|
| 12 |
progress(0, desc="Reading PDF")
|
|
|
|
| 35 |
repo_id=llm,
|
| 36 |
task="text-generation",
|
| 37 |
max_new_tokens=512,
|
| 38 |
+
top_k=4,
|
| 39 |
+
temperature=0.1
|
| 40 |
)
|
| 41 |
|
| 42 |
chat = ChatHuggingFace(
|
|
|
|
| 47 |
progress(0.5, desc="Initialising RAG Chain")
|
| 48 |
|
| 49 |
retriever = vectorstore.as_retriever()
|
|
|
|
|
|
|
| 50 |
|
| 51 |
+
system_prompt = (
|
| 52 |
+
"You are an assistant for question-answering tasks. "
|
| 53 |
+
"Use the following pieces of retrieved context to answer "
|
| 54 |
+
"the question. If you don't know the answer, say that you "
|
| 55 |
+
"don't know. Use three sentences maximum and keep the "
|
| 56 |
+
"answer concise."
|
| 57 |
+
"\n\n"
|
| 58 |
+
"{context}"
|
| 59 |
+
)
|
| 60 |
|
| 61 |
+
prompt = ChatPromptTemplate.from_messages(
|
| 62 |
+
[
|
| 63 |
+
("system", system_prompt),
|
| 64 |
+
("human", "{input}"),
|
| 65 |
+
]
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
question_answer_chain = create_stuff_documents_chain(chat, prompt)
|
| 69 |
+
rag_chain = create_retrieval_chain(retriever, question_answer_chain)
|
| 70 |
+
|
| 71 |
+
progress(0.9, desc="Complete")
|
| 72 |
|
| 73 |
return rag_chain, progress
|
| 74 |
|
| 75 |
def send(message, rag_chain, chat_history):
|
| 76 |
+
response = rag_chain.invoke({"input": message})
|
| 77 |
+
chat_history.append((message, response["answer"]))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
+
return "", chat_history
|
| 80 |
|
| 81 |
with gr.Blocks() as demo:
|
| 82 |
|
|
|
|
| 87 |
gr.Markdown("<H3>Upload and ask questions about your PDF files</H3>")
|
| 88 |
gr.Markdown("<H6>Note: This project uses LangChain to perform RAG (Retrieval Augmented Generation) on PDF files, allowing users to ask any questions related to their contents. When a PDF file is uploaded, it is embedded and stored in an in-memory Chroma vectorstore, which the chatbot uses as a source of knowledge when aswering user questions.</H6>")
|
| 89 |
|
|
|
|
| 90 |
with gr.Tab("Vectorstore"):
|
| 91 |
with gr.Row():
|
| 92 |
input_pdf = gr.File()
|
|
|
|
| 103 |
with gr.Row():
|
| 104 |
vectorstore_initialisation_progress = gr.Textbox(value="None", label="Initialization")
|
| 105 |
|
|
|
|
| 106 |
with gr.Tab("RAG Chain"):
|
| 107 |
with gr.Row():
|
| 108 |
+
language_model = gr.Radio(["microsoft/Phi-3-mini-4k-instruct", "mistralai/Mistral-7B-Instruct-v0.2", "HuggingFaceH4/zephyr-7b-beta", "mistralai/Mixtral-8x7B-Instruct-v0.1"])
|
| 109 |
with gr.Row():
|
| 110 |
with gr.Column(scale=1, min_width=0):
|
| 111 |
pass
|
|
|
|
| 119 |
with gr.Row():
|
| 120 |
chain_initialisation_progress = gr.Textbox(value="None", label="Initialization")
|
| 121 |
|
|
|
|
| 122 |
with gr.Tab("Chatbot"):
|
| 123 |
with gr.Row():
|
| 124 |
chatbot = gr.Chatbot()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
with gr.Row():
|
| 126 |
message = gr.Textbox()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
initialise_vectorstore_btn.click(fn=initialise_vectorstore, inputs=input_pdf, outputs=[vectorstore, vectorstore_initialisation_progress])
|
| 129 |
initialise_chain_btn.click(fn=initialise_chain, inputs=[language_model, vectorstore], outputs=[rag_chain, chain_initialisation_progress])
|
| 130 |
+
message.submit(fn=send, inputs=[message, rag_chain, chatbot], outputs=[message, chatbot])
|
|
|
|
| 131 |
|
| 132 |
demo.launch()
|