| |
| from llama_index.core import VectorStoreIndex, StorageContext, load_index_from_storage |
| from llama_index.llms.nvidia import NVIDIA |
| from llama_index.embeddings.nvidia import NVIDIAEmbedding |
| from llama_index.core.llms import ChatMessage, MessageRole |
| from langchain_nvidia_ai_endpoints import NVIDIARerank |
| from langchain_core.documents import Document as LangDocument |
| from llama_index.core import Document as LlamaDocument |
| from llama_index.core import Settings |
| from llama_parse import LlamaParse |
| import streamlit as st |
| import os |
|
|
| |
# API credentials are read from the environment. Either will be None if the
# variable is unset, which surfaces later as auth errors from the clients.
nvidia_api_key = os.getenv("NVIDIA_KEY")
llamaparse_api_key = os.getenv("PARSE_KEY")
|
|
| |
# Chat LLM used for answer generation (low temperature for factual answers).
client = NVIDIA(
    model="meta/llama-3.1-8b-instruct",
    api_key=nvidia_api_key,
    temperature=0.2,   # keep responses grounded rather than creative
    top_p=0.7,
    max_tokens=1024
)
|
|
# Embedding model for indexing and query embedding.
# truncate="NONE": over-length inputs are rejected instead of silently cut.
embed_model = NVIDIAEmbedding(
    model="nvidia/nv-embedqa-e5-v5",
    api_key=nvidia_api_key,
    truncate="NONE"
)
|
|
# Cross-encoder reranker (LangChain wrapper) used to reorder retrieved chunks.
reranker = NVIDIARerank(
    model="nvidia/nv-rerankqa-mistral-4b-v3",
    api_key=nvidia_api_key,
)
|
|
| |
# Register the models globally so llama-index components pick them up
# without needing them passed explicitly at every call site.
Settings.embed_model = embed_model
Settings.llm = client
|
|
| |
# LlamaParse converts the PDF to markdown text for downstream chunking.
parser = LlamaParse(
    api_key=llamaparse_api_key,
    result_type="markdown",
    verbose=True
)
|
|
| |
# Resolve the dataset path relative to this script so the app works
# regardless of the current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))
data_file = os.path.join(script_dir, "PhilDataset.pdf")

# Parse the PDF into a list of llama-parse document objects (network call).
documents = parser.load_data(data_file)
print("Document Parsed")
|
|
| |
def split_text(text, max_tokens=512):
    """Greedily split *text* into whitespace-delimited chunks.

    Words are accumulated into a chunk until adding the next word would
    exceed the budget; the chunk is then flushed and a new one started.

    Note: despite the parameter name, the budget is measured in
    characters (word length plus one joining space each), not model tokens.

    Args:
        text: Input string to split.
        max_tokens: Approximate maximum characters per chunk.

    Returns:
        A list of non-empty chunk strings; empty list for empty input.
        A single word longer than the budget becomes its own chunk
        (previously this case also emitted a spurious empty-string chunk).
    """
    chunks = []
    current_chunk = []
    current_length = 0

    for word in text.split():
        word_length = len(word) + 1  # +1 accounts for the joining space
        # Only flush when the chunk is non-empty: flushing an empty chunk
        # produced "" entries that were later embedded (the original bug).
        if current_chunk and current_length + word_length > max_tokens:
            chunks.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = word_length
        else:
            current_chunk.append(word)
            current_length += word_length

    if current_chunk:
        chunks.append(" ".join(current_chunk))

    return chunks
|
|
| |
# Chunk every parsed document and pre-compute one embedding per chunk.
# NOTE(review): VectorStoreIndex.from_documents() below embeds documents
# itself via Settings.embed_model, so these per-chunk embedding calls appear
# to be redundant (double API cost) — confirm whether all_embeddings is
# actually consumed anywhere before removing.
all_embeddings = []
all_documents = []

for doc in documents:
    text_chunks = split_text(doc.text)
    for chunk in text_chunks:
        embedding = embed_model.get_text_embedding(chunk)
        all_embeddings.append(embedding)
        all_documents.append(LlamaDocument(text=chunk))
print("Embeddings generated")
|
|
| |
# Build the vector index over the chunked documents. Embeddings are computed
# internally by `embed_model`; `from_documents` accepts no `embeddings=`
# parameter, so the previous attempt to pass pre-computed vectors through
# that kwarg never took effect and has been removed.
index = VectorStoreIndex.from_documents(all_documents, embed_model=embed_model)
index.set_index_id("vector_index")
# Persist to disk so subsequent runs can reload instead of re-indexing.
index.storage_context.persist("./storage")
print("Index created")
|
|
| |
# Reload the just-persisted index from disk. This demonstrates the reload
# path; it replaces the in-memory `index` built above with its stored copy.
storage_context = StorageContext.from_defaults(persist_dir="storage")
index = load_index_from_storage(storage_context, index_id="vector_index")
print("Index loaded")
|
|
| |
def query_model_with_context(question):
    """Answer *question* via retrieve -> rerank -> chat over the global index.

    Retrieves the top-3 chunks from the vector index, reranks them with the
    NVIDIA reranker, and sends the single best chunk as the system prompt
    alongside the user's question.

    Args:
        question: The user's natural-language question.

    Returns:
        The model's response text, with any "assistant:" label rewritten to
        "Final Response:". Returns a fallback message when retrieval or
        reranking yields no context (previously this raised IndexError).
    """
    # Retrieve candidate chunks from the vector index.
    retriever = index.as_retriever(similarity_top_k=3)
    nodes = retriever.retrieve(question)
    for node in nodes:
        print(node)

    # Rerank with the cross-encoder; the most relevant document comes first.
    ranked_documents = reranker.compress_documents(
        query=question,
        documents=[LangDocument(page_content=node.text) for node in nodes]
    )

    # Guard: an empty result here previously crashed on ranked_documents[0].
    if not ranked_documents:
        return "No relevant context was found for this question."

    context = ranked_documents[0].page_content
    print(f"Most relevant node: {context}")

    # Inject the best chunk as system context for the chat model.
    messages = [
        ChatMessage(role=MessageRole.SYSTEM, content=context),
        ChatMessage(role=MessageRole.USER, content=str(question))
    ]
    completion = client.chat(messages)

    # client.chat normally returns a response object; coerce defensively.
    if isinstance(completion, (list, tuple)):
        # str() each part: joining non-string parts directly would raise.
        response_text = ' '.join(str(part) for part in completion)
    elif isinstance(completion, str):
        response_text = completion
    else:
        response_text = str(completion)

    response_text = response_text.replace("assistant:", "Final Response:").strip()

    return response_text
|
|
|
|
| |
# --- Streamlit UI ---------------------------------------------------------
st.title("Chat with this Rerank RAG App")
question = st.text_input("Enter a relevant question to chat with the attached PhilDataset PDF file:")

# On submit: reject empty input up front, otherwise run the RAG pipeline
# and render its answer.
if st.button("Submit"):
    if not question:
        st.warning("Please enter a question.")
    else:
        st.write("**RAG Response:**")
        st.write(query_model_with_context(question))
|
|