| import gradio as gr |
| import utils |
| from langchain_mistralai import ChatMistralAI |
| from langchain_core.prompts import ChatPromptTemplate |
| from langchain_core.output_parsers import StrOutputParser |
| from langchain_community.vectorstores import Chroma |
| from langchain_huggingface import HuggingFaceEmbeddings |
| from langchain_core.runnables import RunnablePassthrough |
| import torch |
|
|
| import os |
# SECURITY(review): this API key is committed in source control. Treat it as
# leaked: revoke/rotate it and supply the replacement through the environment
# (or a secrets manager), then delete the literal below.
# `setdefault` keeps the previous behaviour when the variable is unset, but no
# longer overwrites a key that the deployment environment already provides.
os.environ.setdefault('MISTRAL_API_KEY', 'XuyOObDE7trMbpAeI7OXYr3dnmoWy3L0')
|
|
class VectorData:
    """State behind the RAG UI: embeddings, vector store, retriever and chain.

    The Chroma store is opened from the ``chroma_db`` directory, while
    ``ingested_files`` is an in-memory list that starts empty on every run.

    NOTE(review): because the list is not restored from the persisted store,
    documents ingested in a previous run cannot be deleted by name here.
    """

    def __init__(self):
        """Load the embedding model, open the store and build the RAG chain."""
        embedding_model_name = 'l3cube-pune/punjabi-sentence-similarity-sbert'

        # Run the embedding model on the GPU when one is available.
        model_kwargs = {
            'device': 'cuda' if torch.cuda.is_available() else 'cpu',
            "trust_remote_code": True,
        }

        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_name,
            model_kwargs=model_kwargs,
        )

        self.vectorstore = Chroma(
            persist_directory="chroma_db",
            embedding_function=self.embeddings,
        )
        self.retriever = self.vectorstore.as_retriever()
        # Base names of the files ingested during this process lifetime.
        self.ingested_files = []

        # Same text as the original triple-quoted literal: instruction line,
        # then the retrieved context on its own line.
        system_template = (
            "Answer the question based on the given context. Dont give any ans "
            "if context is not valid to question. Always give the source of "
            "context:\n"
            "{context}\n"
        )
        self.prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_template),
                ("human", "{question}"),
            ]
        )
        self.llm = ChatMistralAI(model="mistral-large-latest")
        self.rag_chain = self._build_chain()

    def _build_chain(self):
        """Compose retriever -> prompt -> LLM -> string parser from current state."""
        return (
            {"context": self.retriever, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    def _file_rows(self) -> list[list[str]]:
        """Return the ingested file names as one-column rows for the table."""
        return [[name] for name in self.ingested_files]

    def add_file(self, file) -> list[list[str]]:
        """Ingest *file* into the vector store and return the updated file rows.

        Args:
            file: Uploaded file object exposing a ``.name`` path, or ``None``
                when nothing was uploaded (in which case nothing changes).

        Returns:
            One single-element row per ingested file name.
        """
        if file is not None:
            # basename handles the platform's path separators, not only '/'.
            display_name = os.path.basename(file.name)
            self.retriever, self.vectorstore = utils.add_doc(file, self.vectorstore)
            # Record the name only after ingestion succeeded, so a failure in
            # utils.add_doc does not leave a phantom entry in the list.
            self.ingested_files.append(display_name)
            self.rag_chain = self._build_chain()
        return self._file_rows()

    def delete_file_by_name(self, file_name) -> list[list[str]]:
        """Remove *file_name* from the store if it is listed; return file rows.

        Names that are not in ``ingested_files`` are ignored.
        """
        if file_name in self.ingested_files:
            self.retriever, self.vectorstore = utils.delete_doc(file_name, self.vectorstore)
            self.ingested_files.remove(file_name)
            # The retriever was just replaced; rebuild the chain so questions
            # are not answered through the old retriever.
            self.rag_chain = self._build_chain()
        return self._file_rows()

    def delete_all_files(self) -> list[list[str]]:
        """Remove every document from the store and return an empty row list."""
        self.retriever, self.vectorstore = utils.delete_all_doc(self.vectorstore)
        # Clear the list only after the store operation succeeded.
        self.ingested_files.clear()
        self.rag_chain = self._build_chain()
        return []
| |
# Single shared VectorData instance used by the question and file-management
# callbacks in this module. Constructing it runs VectorData.__init__, i.e. it
# loads the embedding model and opens the Chroma store at import time.
data_obj = VectorData()
|
|
| |
def answer_question(question):
    """Answer *question* through the shared RAG chain.

    Args:
        question: Text typed by the user; may be ``None`` or blank.

    Returns:
        The chain's answer as a string, or a prompt asking for input when the
        question is missing or contains only whitespace.
    """
    # Guard against None as well as blank text: calling .strip() on None
    # would raise AttributeError instead of showing the hint.
    if question is None or not question.strip():
        return "Please enter a question."
    return str(data_obj.rag_chain.invoke(question))
|
|
|
|
| |
# NOTE(review): the original indentation was lost; the nesting below (one Row
# holding two Columns, launch() at module level) is reconstructed from the
# order of the statements and should be confirmed against the intended layout.
with gr.Blocks() as rag_interface:
    gr.Markdown("# RAG Interface")
    gr.Markdown("Manage documents and ask questions with a Retrieval-Augmented Generation (RAG) system.")

    with gr.Row():
        # Left column: ingest, list and delete documents.
        with gr.Column():
            gr.Markdown("### File Management")

            file_input = gr.File(label="Upload File to Ingest")
            add_file_button = gr.Button("Ingest File")

            # Read-only table filled from the rows returned by the handlers.
            ingested_files_box = gr.Dataframe(
                headers=["Files"],
                datatype="str",
                row_count=4,
                interactive=False,
            )

            delete_option = gr.Radio(
                choices=["Delete by File Name", "Delete All Files"],
                label="Delete Option",
            )
            file_name_input = gr.Textbox(label="Enter File Name to Delete", visible=False)
            delete_button = gr.Button("Delete Selected")

            def toggle_file_input(option):
                """Show the file-name box only for the delete-by-name option."""
                return gr.update(visible=(option == "Delete by File Name"))

            delete_option.change(
                fn=toggle_file_input,
                inputs=delete_option,
                outputs=file_name_input,
            )

            add_file_button.click(
                fn=data_obj.add_file,
                inputs=file_input,
                outputs=ingested_files_box,
            )

            def delete_action(selected_option, file_name):
                """Run the chosen delete operation and return the file rows.

                Args:
                    selected_option: Value of the "Delete Option" radio, or
                        ``None`` when nothing is selected.
                    file_name: Text from the file-name box.

                Returns:
                    Rows for the ingested-files table after the operation; the
                    current rows unchanged when there is nothing to do.
                """
                # Prefer the exact text when it matches a listed file;
                # otherwise drop stray surrounding whitespace so a padded
                # name still matches instead of silently deleting nothing.
                if file_name in data_obj.ingested_files:
                    target = file_name
                else:
                    target = (file_name or "").strip()

                if selected_option == "Delete by File Name" and target:
                    return data_obj.delete_file_by_name(target)
                if selected_option == "Delete All Files":
                    return data_obj.delete_all_files()
                return [[name] for name in data_obj.ingested_files]

            delete_button.click(
                fn=delete_action,
                inputs=[delete_option, file_name_input],
                outputs=ingested_files_box,
            )

        # Right column: ask a question against the ingested documents.
        with gr.Column():
            gr.Markdown("### Ask a Question")

            question_input = gr.Textbox(label="Enter your question")

            ask_button = gr.Button("Get Answer")
            answer_output = gr.Textbox(label="Answer", interactive=False)

            ask_button.click(
                fn=answer_question,
                inputs=question_input,
                outputs=answer_output,
            )

rag_interface.launch()
|
|