# 1. Vector store - FAISS, via the LangChain vector store interface
#    https://python.langchain.com/v0.1/docs/modules/data_connection/vectorstores/
# 2. Embeddings - HuggingFaceInferenceAPIEmbeddings with "BAAI/bge-base-en-v1.5"
# 3. LLMs - Mistral and Llama, served by the Hugging Face Inference API:
#    "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
#    "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
import os

import gradio as gr
from langchain.prompts import ChatPromptTemplate
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
# The Hugging Face token is read from the HUGGINGFACE_API_KEY environment variable.
HF_API_KEY = os.environ.get('HUGGINGFACE_API_KEY')

llm_urls = {
    "Mistral 7B": "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2",
    "Llama 8B": "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct",
}

def initialize_vector_store_retriever(file):
    # Load the document, split it into chunks, embed each chunk, and load it into the vector store.
    raw_documents = TextLoader(file).load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    documents = text_splitter.split_documents(raw_documents)
    # HuggingFaceInferenceAPIEmbeddings builds the Inference API URL from the model
    # name, so the model is passed as model_name rather than as an endpoint URL.
    embeddings = HuggingFaceInferenceAPIEmbeddings(
        model_name="BAAI/bge-base-en-v1.5",
        api_key=HF_API_KEY,
    )
    db = FAISS.from_documents(documents, embeddings)
    retriever = db.as_retriever()
    return retriever
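
# A minimal usage sketch (illustrative; assumes a local text file such as ./llm.txt):
#   retriever = initialize_vector_store_retriever('./llm.txt')
#   docs = retriever.invoke("What is a vector store?")  # most similar chunks first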

def generate_llm_rag_prompt() -> ChatPromptTemplate:
    # Llama-2-style instruction format with a system block.
    template = "<s>[INST] <<SYS>>{system}<</SYS>>{context} {prompt} [/INST]"
    prompt_template = ChatPromptTemplate.from_template(template)
    return prompt_template
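
# Illustrative render of the template (the retriever fills {context} with the matching
# Document objects, so the real string is longer):
#   <s>[INST] <<SYS>>You are a helpful and honest assistant. ...<</SYS>>
#   [Document(page_content='...')] What is a vector store? [/INST]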

def create_chain(retriever, llm):
    url = llm_urls[llm]
    model_endpoint = HuggingFaceEndpoint(
        endpoint_url=url,
        huggingfacehub_api_token=HF_API_KEY,
        task="text-generation",  # both instruct models are causal LMs, not text2text models
        max_new_tokens=200,
    )
    if retriever is not None:
        def get_system(input):
            return "You are a helpful and honest assistant. Please respond concisely and truthfully."
        retrieval = {"context": retriever, "prompt": RunnablePassthrough(), "system": get_system}
        chain = retrieval | generate_llm_rag_prompt() | model_endpoint
        return chain, model_endpoint
    return None, model_endpoint
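
# Dataflow of the LCEL chain above: chain.invoke(question) fans the question out to the
# retriever (filling {context}), to RunnablePassthrough() (filling {prompt}), and to
# get_system() (filling {system}); the rendered prompt is then sent to the endpoint.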

def query(question_text, llm, session_data):
    if question_text == "":
        without_rag_text = "Query result without RAG is not available. Enter a question first."
        rag_text = "Query result with RAG is not available. Enter a question first."
        return without_rag_text, rag_text
    retriever = session_data[0] if len(session_data) > 0 else None
    chain, model_endpoint = create_chain(retriever, llm)
    without_rag_text = "Query result without RAG:\n\n" + model_endpoint.invoke(question_text).strip()
    if retriever is None:
        rag_text = "Query result with RAG is not available. Load the vector store first."
    else:
        answer = chain.invoke(question_text).strip()
        # Some endpoints echo the prompt back; keep only the text after the final [/INST] marker.
        if "[/INST]" in answer:
            answer = answer.split("[/INST]")[-1].strip()
        rag_text = "Query result with RAG:\n\n" + answer
    return without_rag_text, rag_text

def upload_file(file, session_data):
    # Build a fresh retriever from the uploaded file and keep it in the session state.
    session_data = [initialize_vector_store_retriever(file)]
    return gr.File(value=file, visible=True), session_data

def initialize_vector_store(session_data):
    # Only used if the commented-out "Load text file to Vector Store" button below is
    # re-enabled; builds the retriever from a local sample file.
    session_data = [initialize_vector_store_retriever('./llm.txt')]
    return session_data
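
# session_data is a per-session gr.State list holding the retriever, so each browser
# session gets its own vector store instead of sharing a global one.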

with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Retrieval Augmented Generation</h1>""")
    session_data = gr.State([])
    file_output = gr.File(visible=False)
    upload_button = gr.UploadButton("Click to Upload a Text File to the Vector Store", file_types=["text"], file_count="single")
    upload_button.upload(upload_file, [upload_button, session_data], [file_output, session_data])
    #initialize_VS_button = gr.Button("Load text file to Vector Store")
    with gr.Row():
        with gr.Column(scale=4):
            question_text = gr.Textbox(show_label=False, placeholder="Ask a question", lines=2)
        with gr.Column(scale=1):
            llm_Choice = gr.Radio(["Llama 8B", "Mistral 7B"], value="Mistral 7B", label="Select language model:", info="")
            query_Button = gr.Button("Query")
    with gr.Row():
        with gr.Column(scale=1):
            without_rag_text = gr.Textbox(show_label=False, placeholder="Query result without using RAG", lines=15)
        with gr.Column(scale=1):
            rag_text = gr.Textbox(show_label=False, placeholder="Query result with RAG", lines=15)
    #initialize_VS_button.click(
    #    initialize_vector_store,
    #    [session_data],
    #    [session_data],
    #    #show_progress=True,
    #)
    query_Button.click(
        query,
        [question_text, llm_Choice, session_data],
        [without_rag_text, rag_text],
        #show_progress=True,
    )

demo.queue().launch(share=False, inbrowser=True)
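
# To run locally (assumption: saved as app.py): set HUGGINGFACE_API_KEY in the
# environment, then `python app.py`; inbrowser=True opens the Gradio UI automatically.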