| import os |
| import httpx |
| from llama_index.core import Document |
| from llama_index.core import Settings |
| from llama_index.core import SimpleDirectoryReader |
| from llama_index.core import StorageContext |
| from llama_index.core import VectorStoreIndex |
|
|
| from llama_index.vector_stores.chroma import ChromaVectorStore |
|
|
| import chromadb |
| import re |
| from llama_index.llms.cohere import Cohere |
| from llama_index.embeddings.cohere import CohereEmbedding |
|
|
|
|
| from llama_index.core.memory import ChatMemoryBuffer |
| from llama_index.core.chat_engine import CondensePlusContextChatEngine |
|
|
| import gradio as gr |
| import uuid |
|
|
# Credentials and endpoint are taken from the environment; each is None when
# the corresponding variable is unset. base_url is read here but is not used
# by any code in this part of the file.
api_key = os.environ.get("API_KEY")
base_url = os.environ.get("BASE_URL")

# Cohere chat model that generates the answers.
llm = Cohere(api_key=api_key, model="command")

# Cohere embedding model used when indexing and retrieving documents.
embedding_model = CohereEmbedding(
    api_key=api_key,
    model_name="embed-multilingual-v3.0",
    input_type="search_document",
    embedding_type="int8",
)

# Module-level conversation state. Both start as empty-string placeholders
# and are rebound by infer() once a document or URL has been ingested.
memory = ""
db_path = ""

# Register the models and generation limits as LlamaIndex global defaults.
Settings.llm = llm
Settings.embed_model = embedding_model
Settings.context_window = 4096
Settings.num_output = 512
|
|
def validate_url(url):
    """Fetch *url* over HTTP and wrap the response body in a ``Document``.

    Args:
        url: Address requested with an HTTP GET (60-second timeout).

    Returns:
        A ``(documents, option)`` tuple: ``documents`` is a one-element list
        holding a ``Document`` built from the response text, and ``option``
        is the string ``"web"``.

    Raises:
        gr.Error: If the request cannot be completed, the server answers
            with an error status, or anything else goes wrong. The original
            exception is chained as ``__cause__`` so the traceback keeps the
            root cause.
    """
    try:
        response = httpx.get(url, timeout=60.0)
        response.raise_for_status()
        # Document construction stays inside the try so that any failure
        # here is also reported to the UI as a gr.Error.
        text = [Document(text=response.text)]
        option = "web"
        return text, option
    except httpx.RequestError as e:
        raise gr.Error(f"An error occurred while requesting {url}: {str(e)}") from e
    except httpx.HTTPStatusError as e:
        raise gr.Error(f"Error response {e.response.status_code} while requesting {url}") from e
    except Exception as e:
        # Boundary handler: convert anything unexpected into a UI error.
        raise gr.Error(f"An unexpected error occurred: {str(e)}") from e
|
|
def extract_web(url):
    """Fetch a web page through the ``r.jina.ai`` prefix endpoint.

    The given URL is appended to ``https://r.jina.ai/`` and the combined
    address is handed to ``validate_url``.
    NOTE(review): presumably r.jina.ai returns a text rendering of the
    target page -- confirm against the service documentation.

    Args:
        url: The page address supplied by the user.

    Returns:
        Whatever ``validate_url`` returns for the prefixed address.
    """
    print("Entered Webpage Extraction")
    target = "https://r.jina.ai/" + url
    print(target)
    # The "exited" trace is printed before the request is actually made.
    print("Exited Webpage Extraction")
    return validate_url(target)
|
|
def extract_doc(path):
    """Load documents from the given files with ``SimpleDirectoryReader``.

    Args:
        path: Value passed as ``input_files`` to the reader (infer() passes
            the list of uploaded files).

    Returns:
        A ``(documents, option)`` tuple where ``option`` is the string
        ``"doc"``.
    """
    reader = SimpleDirectoryReader(input_files=path)
    return reader.load_data(), "doc"
|
|
|
|
def create_col(documents):
    """Index *documents* into a new on-disk Chroma collection.

    A fresh directory under ``database/`` is chosen for every call, a
    ``"quickstart"`` collection is created in it, and the documents are
    embedded and stored through a ``VectorStoreIndex``.

    Args:
        documents: Documents to index.

    Returns:
        The path of the Chroma database directory that now holds the index.
    """
    # Use the full 128-bit UUID for the directory name. A 4-character prefix
    # has only 65,536 possible values, so separate uploads could land in the
    # same directory and get_or_create_collection would silently mix them.
    db_path = f'database/{uuid.uuid4().hex}'
    client = chromadb.PersistentClient(path=db_path)
    chroma_collection = client.get_or_create_collection("quickstart")

    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    # Called for its side effect: embeds the documents and writes them to
    # the Chroma collection. The index object itself is not kept.
    VectorStoreIndex.from_documents(
        documents, storage_context=storage_context
    )
    return db_path
|
|
def infer(message: dict, history: list):
    """Gradio chat handler: ingest a document or URL, then answer questions.

    If the message carries files, or its text starts with ``http://`` or
    ``https://``, the content is indexed into a new Chroma collection and the
    chat memory is reset. Otherwise the text is treated as a question and
    answered against the most recently built index.

    Args:
        message: Multimodal message mapping with a ``"text"`` string and a
            ``"files"`` list.
        history: Previous chat turns; only its emptiness is inspected.

    Returns:
        The reply as a string.

    Raises:
        gr.Error: If no document or URL has been provided yet, or if
            fetching a URL fails.
    """
    # NOTE(review): db_path and memory are module globals, so every chat
    # session served by this process shares the same index and chat memory.
    global db_path
    global memory
    option = ""
    print(f'message: {message}')
    print(f'history: {history}')
    files_list = message["files"]
    text = message["text"]

    if files_list:
        documents, option = extract_doc(files_list)
        db_path = create_col(documents)
        memory = ChatMemoryBuffer.from_defaults(token_limit=3900)
    elif text.startswith(("http://", "https://")):
        documents, option = extract_web(text)
        db_path = create_col(documents)
        memory = ChatMemoryBuffer.from_defaults(token_limit=3900)
    elif not history:
        raise gr.Error("Please send an URL or document")

    # A question can arrive with non-empty history while no index exists.
    # Fail with the same user-facing message instead of letting chromadb
    # open an empty path.
    if not db_path:
        raise gr.Error("Please send an URL or document")

    if option == "web" and not history:
        # A URL sent as the very first message is only acknowledged.
        response = "Getcha! Now ask your question."
    else:
        # Open the persisted collection only when a question is answered.
        load_client = chromadb.PersistentClient(path=db_path)
        chroma_collection = load_client.get_collection("quickstart")
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        index = VectorStoreIndex.from_vector_store(
            vector_store,
        )

        chat_engine = CondensePlusContextChatEngine.from_defaults(
            index.as_retriever(),
            memory=memory,
            context_prompt=(
                "You are an assistant for question-answering tasks."
                "Use the following context to answer the question:\n"
                "{context_str}"
                "\nIf you don't know the answer, just say that you don't know."
                "Use five sentences maximum and keep the answer concise."
                "\nInstruction: Use the previous chat history, or the context above, to interact and help the user."
            ),
            verbose=True,
        )
        response = chat_engine.chat(
            text
        )

    print(type(response))
    print(f'response: {response}')

    return str(response)
| |
|
|
|
|
# Stylesheet passed to gr.Blocks: hides the page footer and centres <h1>.
css = """
footer {
display:none !important
}
h1 {
text-align: center;
display: block;
}
"""

# HTML heading passed to the chat interface as its title.
title = """
<h1>RAG demo</h1>
<p style="text-align: center">Retrieval for web and documents</p>
"""

# Chat display component; the placeholder is passed to gr.Chatbot as its
# placeholder text.
chatbot = gr.Chatbot(
    placeholder="Please send an URL or document file at first<br>Then ask question and get an answer.",
    height=800,
)
|
|
# Assemble the UI: one multimodal chat interface backed by infer().
# fill_height is a boolean option, so pass True rather than the string "true".
with gr.Blocks(theme="soft", css=css, fill_height=True) as demo:
    gr.ChatInterface(
        fn=infer,
        title=title,
        multimodal=True,
        chatbot=chatbot,
    )


if __name__ == "__main__":
    # Launched with api_open=False, show_api=False and share=False.
    demo.queue(api_open=False).launch(show_api=False, share=False)