| import streamlit as st |
| import os |
| import dotenv |
| import uuid |
| import logging |
|
|
| |
# Point every Hugging Face cache variable at one directory under /tmp. This
# runs before the langchain / rag_methods imports further down the file.
# NOTE(review): presumably so those libraries see the cache paths at import
# time - confirm before moving these lines.
os.environ.update(
    dict.fromkeys(
        ("HF_HOME", "TRANSFORMERS_CACHE", "HUGGINGFACE_HUB_CACHE"),
        "/tmp/.cache/huggingface",
    )
)

# Create the working directories up front; ones that already exist are
# left untouched.
for _runtime_dir in (
    "/tmp/.cache/huggingface",
    "/tmp/chroma_persistent_db",
    "/tmp/source_files",
):
    os.makedirs(_runtime_dir, exist_ok=True)

# INFO-level logging for the whole process, plus this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
| from langchain.schema import HumanMessage, AIMessage |
| from langchain_groq import ChatGroq |
| from rag_methods import ( |
| load_doc_to_db, |
| load_url_to_db, |
| stream_llm_response, |
| stream_llm_rag_response, |
| ) |
|
|
# Load settings from a local .env file (if there is one) before
# GROQ_API_KEY is read from the environment further down.
dotenv.load_dotenv()
|
|
| |
def apply_custom_css():
    """Inject the app's custom stylesheet into the page.

    Emits a single ``<style>`` element through ``st.markdown`` with
    ``unsafe_allow_html=True``. The rules cover the main container padding,
    heading fonts, the ``.app-title`` banner, the chat and message
    containers, the document list, the upload box and the status badges,
    plus a narrow-screen override for the title.
    """
    # No blank lines inside the literal: the whole stylesheet is passed as
    # one uninterrupted HTML chunk.
    stylesheet = """
    <style>
    .main .block-container {
    padding-top: 2rem;
    padding-bottom: 2rem;
    }
    h1, h2, h3, h4 {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    font-weight: 600;
    }
    .app-title {
    text-align: center;
    color: #4361ee;
    font-size: 2.2rem;
    font-weight: 700;
    margin-bottom: 1.5rem;
    padding: 1rem;
    border-radius: 10px;
    background: linear-gradient(90deg, rgba(67, 97, 238, 0.1), rgba(58, 12, 163, 0.1));
    text-shadow: 0px 0px 2px rgba(0,0,0,0.1);
    }
    .chat-container {
    border-radius: 10px;
    padding: 10px;
    margin-bottom: 1rem;
    }
    .message-container {
    padding: 0.8rem;
    margin-bottom: 0.8rem;
    border-radius: 8px;
    }
    .user-message {
    background-color: rgba(67, 97, 238, 0.15);
    border-left: 4px solid #4361ee;
    }
    .assistant-message {
    background-color: rgba(58, 12, 163, 0.1);
    border-left: 4px solid #3a0ca3;
    }
    .document-list {
    background-color: rgba(67, 97, 238, 0.05);
    border-radius: 8px;
    padding: 0.7rem;
    }
    .upload-container {
    border: 2px dashed rgba(67, 97, 238, 0.5);
    border-radius: 10px;
    padding: 1rem;
    margin-bottom: 1rem;
    text-align: center;
    }
    .status-indicator {
    font-size: 0.85rem;
    font-weight: 600;
    padding: 0.3rem 0.7rem;
    border-radius: 20px;
    display: inline-block;
    margin-bottom: 0.5rem;
    }
    .status-active {
    background-color: rgba(46, 196, 182, 0.2);
    color: #2EC4B6;
    }
    .status-inactive {
    background-color: rgba(231, 111, 81, 0.2);
    color: #E76F51;
    }
    @media screen and (max-width: 768px) {
    .app-title {
    font-size: 1.8rem;
    padding: 0.7rem;
    }
    }
    </style>
    """
    st.markdown(stylesheet, unsafe_allow_html=True)
|
|
| |
# Page-level configuration; this is the first Streamlit call the script
# makes (apply_custom_css is only defined above, not yet called).
# NOTE(review): the icon/emoji text in the literals below looks mis-encoded;
# it is kept exactly as found - confirm the intended characters.
_page_options = {
    "page_title": "RAG-Xpert: An Enhanced RAG Framework",
    "page_icon": "π",
    "layout": "centered",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_page_options)

# Install the stylesheet before any element that relies on its classes.
apply_custom_css()

# Banner heading, styled by the .app-title rule.
_title_html = '<h1 class="app-title">π RAG-Xpert: An Enhanced Retrieval-Augmented Generation Framework π€</h1>'
st.markdown(_title_html, unsafe_allow_html=True)
|
|
| |
# Seed per-session values only when they are missing; reruns keep whatever
# is already stored in st.session_state. Each default is built lazily so a
# new UUID / list is created only when the key is actually absent.
_session_defaults = {
    "session_id": lambda: str(uuid.uuid4()),
    "rag_sources": list,
    "messages": lambda: [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there! How can I assist you today?"},
    ],
}
for _state_key, _make_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
|
|
| |
# Sidebar: author credit, RAG controls, source inputs and the source list.
# NOTE(review): several labels below contain text that looks like
# mis-encoded emoji; they are kept exactly as found - confirm the intended
# characters.
with st.sidebar:
    # Author credit card.
    st.markdown(
        """
        <div style="
        text-align: center;
        padding: 1rem 0;
        margin-bottom: 1.5rem;
        background: linear-gradient(to right, #4361ee22, #3a0ca322);
        border-radius: 10px;">
        <div style="font-size: 0.85rem; color: #888;">Developed By</div>
        <div style="font-size: 1.2rem; font-weight: 700; color: #4361ee;">Uditanshu Pandey</div>
        </div>
        """,
        unsafe_allow_html=True,
    )

    # The RAG toggle is usable only once a vector DB exists in this session;
    # until then it is rendered disabled and defaults to off.
    is_vector_db_loaded = st.session_state.get("vector_db") is not None
    rag_status = st.toggle(
        "Enable Knowledge Enhancement (RAG)",
        value=is_vector_db_loaded,
        key="use_rag",
        disabled=not is_vector_db_loaded,
    )

    # Status badge mirroring the toggle.
    if rag_status:
        _status_html = '<div class="status-indicator status-active">RAG Mode: Active β</div>'
    else:
        _status_html = '<div class="status-indicator status-inactive">RAG Mode: Inactive β</div>'
    st.markdown(_status_html, unsafe_allow_html=True)

    st.toggle("Show Retrieved Context", key="debug_mode", value=False)

    def _clear_chat_history():
        """Empty the stored conversation in place (button callback)."""
        st.session_state.messages.clear()

    st.button("π§Ή Clear Chat History", on_click=_clear_chat_history, type="primary")

    # Knowledge-source inputs; the on_change callbacks come from rag_methods.
    st.markdown(
        "<h3 style='text-align: center; color: #4361ee; margin-top: 1.5rem;'>π Knowledge Sources</h3>",
        unsafe_allow_html=True,
    )
    st.markdown('<div class="upload-container">', unsafe_allow_html=True)
    st.file_uploader(
        "π Upload Documents",
        type=["pdf", "txt", "docx", "md"],
        accept_multiple_files=True,
        on_change=load_doc_to_db,
        key="rag_docs",
    )
    st.markdown('</div>', unsafe_allow_html=True)

    st.text_input(
        "π Add Webpage URL",
        placeholder="https://example.com",
        on_change=load_url_to_db,
        key="rag_url",
    )

    # Sources are counted only while a vector DB is loaded.
    doc_count = len(st.session_state.rag_sources) if is_vector_db_loaded else 0
    with st.expander(f"π Knowledge Base ({doc_count} sources)"):
        if not doc_count:
            st.info("No documents added yet. Upload files or add URLs to enhance the assistant's knowledge.")
        else:
            st.markdown('<div class="document-list">', unsafe_allow_html=True)
            for _position, _source in enumerate(st.session_state.rag_sources, start=1):
                st.markdown(f"**{_position}.** {_source}")
            st.markdown('</div>', unsafe_allow_html=True)
|
|
| |
# Groq chat model handed to both the plain and the RAG streaming helpers
# below. The API key is read from the environment; it is None when
# GROQ_API_KEY is not set.
_groq_settings = {
    "model_name": "meta-llama/llama-4-scout-17b-16e-instruct",
    "api_key": os.environ.get("GROQ_API_KEY"),
    "temperature": 0.4,
    "max_tokens": 1024,
}
llm_stream = ChatGroq(**_groq_settings)
|
|
| |
# Replay the stored conversation, one chat bubble per message.
st.markdown('<div class="chat-container">', unsafe_allow_html=True)
for message in st.session_state.messages:
    is_user = message["role"] == "user"
    # Avatars are written as escapes so they survive re-encoding:
    # U+1F464 (bust in silhouette) for the user, U+1F916 (robot face) for
    # the assistant. The previous literals were mis-encoded text rather than
    # emoji, which st.chat_message cannot use as an avatar.
    avatar = "\U0001F464" if is_user else "\U0001F916"
    css_class = "user-message" if is_user else "assistant-message"
    with st.chat_message(message["role"], avatar=avatar):
        # NOTE(review): the message text is interpolated into raw HTML
        # without escaping while unsafe_allow_html=True - confirm this is
        # acceptable for user- and model-supplied content.
        st.markdown(
            f'<div class="message-container {css_class}">{message["content"]}</div>',
            unsafe_allow_html=True,
        )
st.markdown('</div>', unsafe_allow_html=True)
|
|
| |
# Handle a newly submitted prompt: store it, echo it, then stream the reply.
if prompt := st.chat_input("Ask me anything..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Avatars are written as escapes so they survive re-encoding:
    # U+1F464 (bust in silhouette) for the user, U+1F916 (robot face) for
    # the assistant. The previous literals were mis-encoded text rather than
    # emoji, which st.chat_message cannot use as an avatar.
    with st.chat_message("user", avatar="\U0001F464"):
        # NOTE(review): the prompt is interpolated into raw HTML without
        # escaping while unsafe_allow_html=True - confirm this is acceptable
        # for user-supplied text.
        st.markdown(
            f'<div class="message-container user-message">{prompt}</div>',
            unsafe_allow_html=True,
        )

    with st.chat_message("assistant", avatar="\U0001F916"):
        thinking_placeholder = st.empty()
        thinking_placeholder.info("Thinking... Please wait a moment.")

        # Convert the stored history (including the prompt just appended)
        # into LangChain message objects; any non-"user" role is treated as
        # an assistant message.
        messages = [
            HumanMessage(content=m["content"]) if m["role"] == "user" else AIMessage(content=m["content"])
            for m in st.session_state.messages
        ]

        # This block never appends the reply to st.session_state.messages.
        # NOTE(review): presumably the rag_methods stream helpers do that -
        # confirm.
        if st.session_state.use_rag:
            thinking_placeholder.info("Searching knowledge base... Please wait.")
            try:
                st.write_stream(stream_llm_rag_response(llm_stream, messages))
            finally:
                # Clear the status banner once the stream ends (or fails) so
                # it does not linger above the finished answer, matching the
                # non-RAG branch which also removes it.
                thinking_placeholder.empty()
        else:
            thinking_placeholder.empty()
            st.write_stream(stream_llm_response(llm_stream, messages))