Spaces:
Sleeping
Sleeping
import logging
import os
import subprocess
from datetime import datetime, timezone
from typing import List, Dict, Tuple, Optional, Any

import streamlit as st

# Disable telemetry for LangChain and Chroma by default.
# NOTE(review): kept ahead of the project imports below, presumably so the
# settings are in place before those libraries are loaded — confirm.
os.environ.setdefault("LANGCHAIN_TELEMETRY_ENABLED", "false")
os.environ.setdefault("LANGCHAIN_DISABLE_TELEMETRY", "true")
os.environ.setdefault("CHROMA_TELEMETRY_ENABLED", "false")

from src.utils.rag_runtime import (
    run_ingest_cli,
    build_or_load_retriever_cached,
    get_chain_cached,
    answer_with_kg,
)
from src.utils.metrics import compute_quality_scores
from src.utils.formatting import format_source_label
from src.utils.env import ensure_openai_key
class AbaloneRAGApp:
    """Main application class for the Abalone RAG Chatbot.

    The constructor renders the page header and the sidebar; the remaining
    public methods render the rebuild dialog, load the QA chain and drive the
    chat UI. Chat history and the pending-rebuild flag live in
    ``st.session_state``; everything else is per-instance state.
    """

    # Sidebar label -> retrieval-mode identifier passed to the chain factory.
    _RETRIEVAL_MODES: Dict[str, str] = {
        "MMR (diverse)": "mmr",
        "Similarity": "similarity",
        "Hybrid (dense + MMR)": "hybrid",
    }

    # Answer-length choice -> instruction sentence prepended to each question.
    _LENGTH_INSTRUCTIONS: Dict[str, str] = {
        "Short": "Answer in 1–3 sentences.",
        "Medium": "Answer in 1–2 paragraphs.",
        "Long": "Provide a detailed, multi-paragraph explanation.",
    }

    # Maximum number of characters of a retrieved chunk shown in the UI.
    _SNIPPET_CHARS = 200

    # Colours the confirm (green) and cancel (red) buttons of the dialog.
    _REBUILD_BUTTON_CSS = """
        <style>
        div[data-testid="column"] div:has(> button[aria-label="Yes, rebuild"]) button {
            background-color: #27ae60 !important;
            color: white !important;
        }
        div[data-testid="column"] div:has(> button[aria-label="Cancel"]) button {
            background-color: #c0392b !important;
            color: white !important;
        }
        </style>
        """

    def __init__(self) -> None:
        """Initialize the Streamlit page and application state."""
        st.set_page_config(page_title="Abalone RAG Chatbot", page_icon="🐚")

        # Header row: title/subtitle on the left, rebuild action on the right.
        header_col, action_col = st.columns([5, 1])
        with header_col:
            st.title("Abalone RAG Chatbot")
            st.write(
                "Ask natural-language questions about abalone biology, ecology, "
                "and research datasets. The app uses a local Chroma vectorstore "
                "and OpenAI to retrieve and answer questions accurately."
            )
        with action_col:
            # A compact, prominent rebuild control placed in the header.
            self._top_rebuild_clicked = st.button(
                "Rebuild vectorstore",
                key="top_rebuild",
                use_container_width=True,
            )

        # Data and vectorstore locations.
        self.data_dir = "./data"
        self.persist_dir = "./vectorstore"

        # Session state survives across runs; keep a direct reference to the
        # stored history list so appends are visible through session state.
        st.session_state.setdefault("chat_history", [])
        st.session_state.setdefault("rebuild_pending", False)
        self.chat_history: List[Dict] = st.session_state["chat_history"]

        # Sidebar configuration.
        (
            self.model_name,
            self.top_k,
            self.retrieval_mode,
            self.temperature,
            self.answer_length,
            self.style_instruction,
            self.use_kg,
            self.kg_hops,
        ) = self._build_sidebar()

        # The rebuild request comes from the top-right header button.
        self.rebuild_clicked = bool(self._top_rebuild_clicked)

        # QA chain instance (loaded lazily). Typed as Any to avoid static
        # warnings when calling the chain object.
        self.chain: Optional[Any] = None

    # ------------------------------------------------------------------
    # Sidebar configuration
    # ------------------------------------------------------------------
    @classmethod
    def _build_style_instruction(cls, answer_length: str, temperature: float) -> str:
        """Return the natural-language style directive for the LLM.

        Args:
            answer_length: One of the keys of ``_LENGTH_INSTRUCTIONS``.
            temperature: Slider value, rendered with two decimals.

        Raises:
            KeyError: If ``answer_length`` is not a known choice.
        """
        return (
            cls._LENGTH_INSTRUCTIONS[answer_length]
            + f" Use a response style appropriate for a temperature of {temperature:.2f}, "
            "where lower values are more factual and higher values are more exploratory."
        )

    def _build_sidebar(self) -> Tuple[str, int, str, float, str, str, bool, int]:
        """Render all sidebar controls and return the chosen configuration.

        Returns:
            Tuple containing, in order:
            - model_name: Which LLM to use.
            - top_k: Number of chunks to retrieve.
            - retrieval_mode: Strategy ("mmr", "similarity" or "hybrid").
            - temperature: Slider value; in this file it is only embedded in
              the style instruction text.
            - answer_length: "Short", "Medium" or "Long".
            - style_instruction: Natural-language style directive.
            - use_kg: Whether the knowledge-graph checkbox is ticked.
            - kg_hops: Selected number of knowledge-graph hops (1-3).
        """
        sidebar = st.sidebar

        sidebar.header("Model Settings")
        model_name = sidebar.selectbox(
            "Model",
            options=["gpt-3.5-turbo", "gpt-4"],
            index=0,
        )
        sidebar.markdown("---")

        # Retrieval configuration.
        sidebar.header("Retrieval Configuration")
        top_k = sidebar.slider(
            "Number of retrieved chunks (k)",
            min_value=2,
            max_value=10,
            value=4,
        )
        retrieval_mode_label = sidebar.selectbox(
            "Retrieval mode",
            list(self._RETRIEVAL_MODES),
            index=2,  # "Hybrid (dense + MMR)"
        )
        retrieval_mode = self._RETRIEVAL_MODES[retrieval_mode_label]

        # Knowledge graph toggle (placed under Retrieval Configuration).
        sidebar.markdown("---")
        sidebar.header("Knowledge Graph")
        use_kg = sidebar.checkbox("Use knowledge graph for retrieval", value=False)
        kg_hops = sidebar.slider("KG hops", min_value=1, max_value=3, value=1)
        sidebar.markdown("---")

        # Answer style.
        sidebar.header("Answer Style")
        temperature = sidebar.slider(
            "Temperature",
            min_value=0.0,
            max_value=1.0,
            value=0.2,
            step=0.05,
        )
        answer_length = sidebar.selectbox(
            "Answer length",
            list(self._LENGTH_INSTRUCTIONS),
            index=1,  # "Medium"
        )

        # The rebuild control lives in the page header; point users to it.
        sidebar.markdown("---")
        sidebar.markdown(
            "<small>To rebuild the vectorstore use the top-right "
            "\"Rebuild vectorstore\" button.</small>",
            unsafe_allow_html=True,
        )

        style_instruction = self._build_style_instruction(answer_length, temperature)
        return (
            model_name,
            top_k,
            retrieval_mode,
            temperature,
            answer_length,
            style_instruction,
            use_kg,
            kg_hops,
        )

    # ------------------------------------------------------------------
    # Shared chain / cache helpers
    # ------------------------------------------------------------------
    def _load_chain(self) -> Any:
        """Return the QA chain for the current sidebar configuration."""
        return get_chain_cached(
            model_name=self.model_name,
            top_k=self.top_k,
            retrieval_mode=self.retrieval_mode,
            data_dir=self.data_dir,
            persist_dir=self.persist_dir,
        )

    @staticmethod
    def _clear_chain_caches() -> None:
        """Clear the cached retriever and chain factories.

        Any exception raised by a ``clear()`` call propagates to the caller.
        """
        build_or_load_retriever_cached.clear()
        get_chain_cached.clear()

    # ------------------------------------------------------------------
    # Vectorstore rebuild workflow
    # ------------------------------------------------------------------
    def _log_rebuild_event(self, message: str) -> None:
        """Append a UTC-timestamped line to ``<persist_dir>/ui_rebuild.log``.

        Logging is best-effort: a failure to write is reported through the
        ``logging`` module and never interrupts the UI.
        """
        stamp = datetime.now(timezone.utc).isoformat()
        try:
            os.makedirs(self.persist_dir, exist_ok=True)
            log_path = os.path.join(self.persist_dir, "ui_rebuild.log")
            with open(log_path, "a", encoding="utf-8") as fh:
                fh.write(f"{stamp} - {message}\n")
        except Exception:
            logging.getLogger(__name__).warning(
                "Could not write rebuild log entry", exc_info=True
            )

    @staticmethod
    def _as_text(stream: Any) -> Any:
        """Decode captured process output when it arrives as bytes."""
        if isinstance(stream, bytes):
            return stream.decode("utf-8", errors="replace")
        return stream

    def _report_rebuild_failure(self, err: Exception) -> None:
        """Show a rebuild error, including captured ingest output if any."""
        if not isinstance(err, subprocess.CalledProcessError):
            st.error(f"Rebuild failed: {err}")
            return

        stdout = self._as_text(
            getattr(err, "output", None) or getattr(err, "stdout", None)
        )
        stderr = self._as_text(getattr(err, "stderr", None))
        st.error("Rebuild failed. See logs below.")
        if stdout:
            st.markdown("**ingest stdout:**")
            st.code(stdout)
        if stderr:
            st.markdown("**ingest stderr:**")
            st.code(stderr)

    def _run_confirmed_rebuild(self) -> None:
        """Rebuild the vectorstore, refresh caches and report the outcome."""
        self._log_rebuild_event("Confirm rebuild clicked by user")
        with st.spinner("Rebuilding vectorstore..."):
            try:
                run_ingest_cli(data_dir=self.data_dir, persist_dir=self.persist_dir)
            except Exception as err:
                self._log_rebuild_event(f"Rebuild failed: {err}")
                self._report_rebuild_failure(err)
                st.session_state["rebuild_pending"] = False
                return
            self._log_rebuild_event("Rebuild succeeded")

            # On success, drop the cached retriever/chain so they are rebuilt
            # against the new vectorstore.
            try:
                self._clear_chain_caches()
            except Exception:
                self._log_rebuild_event("Warning: failed to clear cached functions")

            # The rebuild itself is done: clear the flag before touching the
            # chain so a reload failure cannot leave the dialog stuck open.
            st.session_state["rebuild_pending"] = False

            # This method runs before the OpenAI key check in main(), and
            # chain creation may need that key. Treat the eager reload as
            # optional; ensure_chain_ready() loads the chain lazily otherwise.
            try:
                self.chain = self._load_chain()
            except Exception as err:
                self.chain = None
                self._log_rebuild_event(f"Chain reload deferred after rebuild: {err}")
        st.success("Vectorstore rebuilt successfully.")

    def handle_rebuild(self) -> None:
        """Render the rebuild confirmation dialog and rebuild if confirmed.

        This manages the 2-step rebuild process:
        1. User clicks "Rebuild vectorstore".
        2. A confirmation dialog appears with "Yes, rebuild" and "Cancel".

        If confirmed, the vectorstore is regenerated and caches are cleared.
        The ``rebuild_pending`` session flag is cleared on confirm (success
        or failure) and on cancel.
        """
        if self.rebuild_clicked:
            st.session_state["rebuild_pending"] = True
        if not st.session_state["rebuild_pending"]:
            return

        st.warning(
            "Rebuild the vectorstore from the current contents of ./data? "
            "This will overwrite existing embeddings."
        )
        _, col_center, _ = st.columns([1, 2, 1])
        with col_center:
            confirm = st.button(
                "Yes, rebuild",
                key="confirm_rebuild",
                use_container_width=True,
            )
            cancel = st.button(
                "Cancel",
                key="cancel_rebuild",
                use_container_width=True,
            )
        # Centered green (confirm) and red (cancel) buttons.
        st.markdown(self._REBUILD_BUTTON_CSS, unsafe_allow_html=True)

        if confirm:
            self._run_confirmed_rebuild()
        elif cancel:
            st.session_state["rebuild_pending"] = False
            st.info("Rebuild canceled.")

    # ------------------------------------------------------------------
    # Chain loading
    # ------------------------------------------------------------------
    def ensure_chain_ready(self) -> None:
        """Load or create the QA chain unless a rebuild is still pending."""
        if st.session_state["rebuild_pending"]:
            return
        if self.chain is None:
            with st.spinner("Initializing knowledge base and chat model..."):
                self.chain = self._load_chain()
        st.success("Knowledge base and model are ready.")

    # ------------------------------------------------------------------
    # Chat UI
    # ------------------------------------------------------------------
    def render_chat_history(self) -> None:
        """Render previous user and assistant messages."""
        for turn in self.chat_history:
            with st.chat_message("user"):
                st.markdown(turn["question"])
            with st.chat_message("assistant"):
                st.markdown(turn["answer"])

    def _invoke_chain(self, question: str, history: List[Tuple[str, str]]) -> Any:
        """Run one query, through the knowledge graph when it is enabled."""
        if self.use_kg:
            return answer_with_kg(
                self.chain,
                question,
                history,
                persist_dir=self.persist_dir,
                kg_hops=self.kg_hops,
            )
        return self.chain({"question": question, "chat_history": history})

    def _query_with_recovery(
        self, question: str, history: List[Tuple[str, str]]
    ) -> Optional[Any]:
        """Query the chain, rebuilding the vectorstore once if the query fails.

        If the first attempt raises (for example because the underlying
        vectorstore is corrupted or missing), the vectorstore is rebuilt, the
        caches are cleared, the chain is reloaded and the query is retried
        exactly once. This avoids crashing the app in deployed environments.

        NOTE(review): any exception triggers the rebuild, including ones that
        may be unrelated to the vectorstore; consider narrowing this once the
        backend's exception types are confirmed.

        Returns:
            The chain result, or ``None`` when the query could not be answered
            (the error has already been shown in the UI).
        """
        try:
            return self._invoke_chain(question, history)
        except Exception:
            st.warning(
                "Detected retrieval backend issue — attempting to rebuild the "
                "vectorstore and retry..."
            )

        try:
            run_ingest_cli(data_dir=self.data_dir, persist_dir=self.persist_dir)
        except Exception as rebuild_err:
            # The manual rebuild control is the top-right header button.
            st.error(
                "Automatic rebuild failed; please rebuild manually using the "
                "top-right \"Rebuild vectorstore\" button or the CLI."
            )
            st.exception(rebuild_err)
            return None

        try:
            self._clear_chain_caches()
            self.chain = self._load_chain()
        except Exception as reload_err:
            st.error("Failed to reload the QA chain after rebuilding the vectorstore.")
            st.exception(reload_err)
            return None

        try:
            return self._invoke_chain(question, history)
        except Exception as retry_err:
            st.error(
                "Retrieval error: failed to query the knowledge base. "
                "Try rebuilding the vectorstore manually."
            )
            # Show the exception text for debugging.
            st.exception(retry_err)
            return None

    @staticmethod
    def _extract_answer(result: Any) -> Any:
        """Return the first non-empty answer field of a chain result.

        The keys ``answer``, ``result`` and ``output_text`` are tried in that
        order; an empty string is returned when none holds a truthy value.
        """
        for key in ("answer", "result", "output_text"):
            value = result.get(key)
            if value:
                return value
        return ""

    @staticmethod
    def _normalize_sources(source_docs: Any) -> List[Dict]:
        """Convert retrieved documents into plain dicts for UI and metrics.

        Each document may be a dict or an object; its text is read from
        ``page_content``, ``content`` or ``text`` (first truthy one wins).

        Returns:
            One ``{"index", "metadata", "content"}`` dict per document, with
            1-based indices. ``None`` or an empty input yields ``[]``.
        """
        formatted: List[Dict] = []
        for idx, doc in enumerate(source_docs or [], start=1):
            if isinstance(doc, dict):
                meta = doc.get("metadata") or {}
                text = (
                    doc.get("page_content")
                    or doc.get("content")
                    or doc.get("text")
                    or ""
                )
            else:
                meta = getattr(doc, "metadata", None) or {}
                text = (
                    getattr(doc, "page_content", None)
                    or getattr(doc, "content", None)
                    or getattr(doc, "text", None)
                    or ""
                )
            formatted.append({"index": idx, "metadata": meta, "content": str(text)})
        return formatted

    @classmethod
    def _make_snippet(cls, content: str) -> str:
        """Return a one-line preview, adding "..." only when truncated."""
        snippet = content[: cls._SNIPPET_CHARS].replace("\n", " ")
        if len(content) > cls._SNIPPET_CHARS:
            snippet += "..."
        return snippet

    def _render_assistant_turn(
        self,
        answer: Any,
        formatted_sources: List[Dict],
        coverage_pct: int,
        grounding_pct: int,
    ) -> None:
        """Render the assistant message plus the metrics/sources expander."""
        with st.chat_message("assistant"):
            st.markdown(answer)
            with st.expander("Retrieval Metrics and Sources"):
                st.markdown(f"- Retrieval mode: `{self.retrieval_mode}`")
                st.markdown(f"- k: `{self.top_k}`")
                st.markdown(
                    f"- Coverage score (question vs sources): **{coverage_pct}%**"
                )
                st.markdown(
                    f"- Grounding score (answer vs sources): **{grounding_pct}%**"
                )
                if formatted_sources:
                    st.markdown("**Retrieved chunks:**")
                    for src in formatted_sources:
                        label = format_source_label(src["metadata"], src["index"])
                        st.markdown(f"**[{src['index']}] {label}**")
                        st.code(self._make_snippet(src["content"]))

    def handle_user_input(self) -> None:
        """Process a new user query: run RAG, compute metrics, show results.

        Does nothing while a rebuild is pending, before the chain is loaded,
        or when the chat input is empty. A successfully answered turn is
        appended to the chat history stored in session state.
        """
        if st.session_state["rebuild_pending"] or self.chain is None:
            return

        user_input = st.chat_input(
            "Ask a question about abalone (biology, data, methodology, etc.)"
        )
        if not user_input:
            return

        # Render user message.
        with st.chat_message("user"):
            st.markdown(user_input)

        # Run inference.
        with st.spinner("Thinking..."):
            prior_history: List[Tuple[str, str]] = [
                (turn.get("question"), turn.get("answer", ""))
                for turn in self.chat_history
            ]
            styled_question = self.style_instruction + "\n\nQuestion: " + user_input
            result = self._query_with_recovery(styled_question, prior_history)
        if result is None:
            # The failure has already been reported; stop processing this input.
            return

        answer = self._extract_answer(result)
        formatted_sources = self._normalize_sources(result.get("source_documents"))

        # Compute simple retrieval quality metrics.
        coverage, grounding = compute_quality_scores(
            user_input, answer, formatted_sources
        )
        coverage_pct = int(round(coverage * 100))
        grounding_pct = int(round(grounding * 100))

        # Render assistant message + debug block.
        self._render_assistant_turn(
            answer, formatted_sources, coverage_pct, grounding_pct
        )

        # Persist turn in chat history.
        self.chat_history.append(
            {
                "question": user_input,
                "answer": answer,
                "sources": formatted_sources,
            }
        )
        st.session_state["chat_history"] = self.chat_history
def main() -> None:
    """Run the Abalone RAG Chatbot app.

    Steps, in order: build the app (header and sidebar), process any rebuild
    request, require an OpenAI key, load the QA chain, replay the stored chat
    history and finally handle new user input.
    """
    app = AbaloneRAGApp()

    # Rebuild handling deliberately comes before the key check so users can
    # inspect logs and trigger rebuild operations even when the key isn't
    # set. Chain init requires the key, so it is enforced only afterwards.
    app.handle_rebuild()

    key_available = ensure_openai_key()
    if not key_available:
        st.stop()

    app.ensure_chain_ready()
    app.render_chat_history()
    app.handle_user_input()


if __name__ == "__main__":
    main()