Spaces:
Sleeping
Sleeping
| from langchain_chroma import Chroma | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.prompts import PromptTemplate | |
| from langchain_core.documents import Document | |
| from langchain_core.runnables import ( | |
| RunnableParallel, | |
| RunnablePassthrough, | |
| RunnableLambda, | |
| ) | |
| from typing import List | |
| import os | |
| from datetime import datetime, timedelta | |
| import json | |
| from pathlib import Path | |
# Fix tokenizer warning: force the TOKENIZERS_PARALLELISM environment variable
# to "false" at import time, before the pipeline class below is defined or used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class RAGPipeline:
    """
    Retrieval-augmented generation pipeline backed by a Chroma vector store.

    Combines HuggingFace embeddings, a persistent Chroma collection, and an
    OpenAI-compatible chat model (served by Groq or OpenRouter). Also keeps two
    JSON bookkeeping files: one for rate limiting and one for tracking uploaded
    documents so that expired uploads can be removed.
    """

    # Model configuration for multi-provider support.
    # Each key is a model identifier accepted by `_initialize_llm` and
    # `switch_model`; each value holds:
    #   provider    - "groq" or "openrouter" (selects API key and base URL)
    #   model       - model name sent to the provider
    #   display     - human-readable name returned by `switch_model`
    #   temperature - sampling temperature passed to the chat model
    #   max_tokens  - completion token limit passed to the chat model
    MODEL_CONFIG = {
        "gpt-oss-120b": {
            "provider": "groq",
            "model": "openai/gpt-oss-120b",
            "display": "GPT-OSS 120B (OpenAI)",
            "temperature": 0.1,
            "max_tokens": 1024,
        },
        "llama-3.3-70b": {
            "provider": "groq",
            "model": "llama-3.3-70b-versatile",
            "display": "Llama 3.3 70B (Meta)",
            "temperature": 0.1,
            "max_tokens": 1024,
        },
        "gemma-3-27b": {
            "provider": "openrouter",
            "model": "google/gemma-3-27b-it:free",
            "display": "Gemma 3 27B (Google)",
            "temperature": 0.1,
            "max_tokens": 512,
        },
    }
def __init__(
    self,
    persist_directory: str = "./data/chroma_db",
    default_model: str = "gpt-oss-120b",
):
    """
    Set up embeddings, the vector store, bookkeeping files, the LLM and the chain.

    Expired uploads are purged as part of construction, before the LLM is
    created.

    Args:
        persist_directory: Directory where the Chroma database is persisted
        default_model: Key into MODEL_CONFIG selecting the initial LLM
    """
    # Embedding model (BAAI/bge-small-en-v1.5) running on CPU
    embedder = HuggingFaceEmbeddings(
        model_name="BAAI/bge-small-en-v1.5",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},  # Important for bge models
    )
    self.embeddings = embedder

    # Persistent vector store that embeds with the model above
    self.vector_store = Chroma(
        persist_directory=persist_directory,
        embedding_function=embedder,
    )

    # Bookkeeping files: query timestamps for rate limiting, and upload
    # timestamps for the 7-day retention cleanup. Parent folders are created
    # on demand.
    bookkeeping = (
        ("rate_limit_file", "./data/rate_limit.json"),
        ("doc_metadata_file", "./data/document_metadata.json"),
    )
    for attribute, location in bookkeeping:
        target = Path(location)
        setattr(self, attribute, target)
        target.parent.mkdir(parents=True, exist_ok=True)

    # Purge expired uploads right away
    self._cleanup_old_documents()

    # LLM for the requested default model
    self.current_model = default_model
    self.llm = self._initialize_llm(default_model)

    # Session used to filter retrieval; assigned on every query
    self._current_session_id = None

    # Assemble the retrieval + generation chain
    self.rag_chain = self.create_rag_chain()
| def _initialize_llm(self, model_key: str): | |
| """ | |
| Initialize LLM based on provider and model configuration. | |
| Supports both Groq and OpenRouter providers. | |
| Args: | |
| model_key: Key from MODEL_CONFIG dictionary | |
| Returns: | |
| ChatOpenAI: Configured LLM instance | |
| Raises: | |
| ValueError: If model_key is invalid or required API key is missing | |
| """ | |
| if model_key not in self.MODEL_CONFIG: | |
| raise ValueError( | |
| f"Invalid model key: {model_key}. " | |
| f"Available models: {', '.join(self.MODEL_CONFIG.keys())}" | |
| ) | |
| config = self.MODEL_CONFIG[model_key] | |
| provider = config["provider"] | |
| if provider == "groq": | |
| # Groq API configuration | |
| groq_key = os.getenv("GROQ_API_KEY") | |
| if not groq_key: | |
| raise ValueError( | |
| "GROQ_API_KEY environment variable not set. " | |
| "Get one free at https://console.groq.com/keys" | |
| ) | |
| return ChatOpenAI( | |
| model=config["model"], | |
| openai_api_key=groq_key, | |
| openai_api_base="https://api.groq.com/openai/v1", | |
| temperature=config["temperature"], | |
| max_tokens=config["max_tokens"], | |
| ) | |
| elif provider == "openrouter": | |
| # OpenRouter API configuration | |
| openrouter_key = os.getenv("OPENROUTER_API_KEY") | |
| if not openrouter_key: | |
| raise ValueError( | |
| "OPENROUTER_API_KEY environment variable not set. " | |
| "Get one free at https://openrouter.ai/keys" | |
| ) | |
| return ChatOpenAI( | |
| model=config["model"], | |
| openai_api_key=openrouter_key, | |
| openai_api_base="https://openrouter.ai/api/v1", | |
| temperature=config["temperature"], | |
| max_tokens=config["max_tokens"], | |
| ) | |
| else: | |
| raise ValueError(f"Unknown provider: {provider}") | |
def switch_model(self, model_key: str) -> str:
    """
    Replace the active LLM with another configured model and rebuild the chain.

    Args:
        model_key: Key from MODEL_CONFIG dictionary
    Returns:
        str: Display name of the model that is now active
    Raises:
        ValueError: If model_key is invalid or API key is missing
    """
    # Build the replacement client first; a failure here leaves state untouched
    replacement = self._initialize_llm(model_key)
    self.llm = replacement
    self.current_model = model_key

    # The chain holds a reference to the LLM, so it has to be rebuilt
    self.rag_chain = self.create_rag_chain()

    display_name = self.MODEL_CONFIG[model_key]["display"]
    return display_name
def create_rag_chain(self):
    """
    Build the RAG chain: retrieve once, format the prompt, call the LLM.

    The vector store is queried a single time per invocation and the
    resulting documents feed both the prompt and the returned
    "source_documents". When a session is active, the session restriction is
    sent to the vector store as a metadata filter, so the top-k limit applies
    to documents the session may see (its own uploads plus samples) instead
    of being spent on other sessions' chunks that are discarded afterwards.

    The prompt text comes from `_format_prompt`, which keeps this chain and
    `query_stream` on the same instructions.

    Returns:
        Runnable: Chain whose `invoke(question)` returns a dict with
            "result" (the LLM output) and "source_documents" (the retrieved
            Document objects)
    """
    top_k = 4  # number of chunks handed to the LLM

    def retrieve_for_session(question):
        """Fetch the top-k chunks visible to the currently set session."""
        search_kwargs = {"k": top_k}
        # Read at call time: `query` sets the session just before invoking
        session_id = self._current_session_id
        if session_id:
            # Session's own documents OR sample documents (is_sample=True)
            search_kwargs["filter"] = {
                "$or": [{"session_id": session_id}, {"is_sample": True}]
            }
        retriever = self.vector_store.as_retriever(search_kwargs=search_kwargs)
        return retriever.invoke(question)

    def build_prompt(state):
        """Turn the retrieved documents and the question into the prompt text."""
        docs = state["source_documents"]
        context = "\n\n".join(d.page_content for d in docs)
        names = (d.metadata.get("source", "").split("/")[-1] for d in docs)
        # dict.fromkeys de-duplicates while keeping retrieval order stable
        sources = ", ".join(dict.fromkeys(names))
        return self._format_prompt(context, sources, state["question"])

    # Step 1: retrieve once and carry the question along
    gather = RunnableParallel(
        {
            "source_documents": RunnableLambda(retrieve_for_session),
            "question": RunnablePassthrough(),
        }
    )
    # Step 2: answer from the retrieved documents and pass them through
    respond = RunnableParallel(
        {
            "result": RunnableLambda(build_prompt) | self.llm,
            "source_documents": RunnableLambda(
                lambda state: state["source_documents"]
            ),
        }
    )
    return gather | respond
def add_documents(
    self,
    documents: List[Document],
    session_id: str = None,
    is_sample: bool = False,
) -> None:
    """
    Add processed document chunks to the vector store for retrieval.

    Every chunk is stamped with session, upload-time and sample metadata
    before it is stored. For non-sample uploads, each distinct source in the
    batch is registered for auto-cleanup (not only the first chunk's source),
    so a batch mixing several files does not leave untracked chunks behind.
    An empty batch is a no-op.

    Args:
        documents: List of Document objects with text and metadata
        session_id: User's session ID for isolation (None for samples)
        is_sample: If True, document is global and won't be auto-deleted
    """
    if not documents:
        # Nothing to embed or track
        return

    # Add session and timestamp metadata to each chunk
    now = datetime.now().isoformat()
    owner = "global" if is_sample else session_id
    for doc in documents:
        doc.metadata["session_id"] = owner
        doc.metadata["uploaded_at"] = now
        doc.metadata["is_sample"] = is_sample

    self.vector_store.add_documents(documents)

    # Samples are never auto-deleted, so they are not tracked
    if is_sample:
        return

    # Track every distinct source once, in first-seen order
    sources = dict.fromkeys(
        doc.metadata.get("source", "unknown") for doc in documents
    )
    for source in sources:
        self._track_document(source, session_id=session_id)
| def _check_rate_limit(self) -> bool: | |
| """ | |
| Enforces rate limit of 10 queries per hour by tracking query timestamps. | |
| Returns: | |
| bool: True if within limit, False if exceeded | |
| """ | |
| now = datetime.now() | |
| # Load existing queries if file exists | |
| if self.rate_limit_file.exists(): | |
| try: | |
| with open(self.rate_limit_file, "r") as f: | |
| content = f.read().strip() | |
| if content: # Only parse if file is not empty | |
| data = json.loads(content) | |
| queries = [ | |
| datetime.fromisoformat(q) for q in data.get("queries", []) | |
| ] | |
| else: | |
| queries = [] | |
| except (json.JSONDecodeError, ValueError): | |
| # If file is corrupted, start fresh | |
| queries = [] | |
| else: | |
| queries = [] | |
| # Remove queries older than 1 hour | |
| one_hour_ago = now - timedelta(hours=1) | |
| recent_queries = [q for q in queries if q > one_hour_ago] | |
| # Check limit | |
| if len(recent_queries) >= 10: | |
| return False | |
| # Add current query | |
| recent_queries.append(now) | |
| # Save updated queries | |
| with open(self.rate_limit_file, "w") as f: | |
| json.dump({"queries": [q.isoformat() for q in recent_queries]}, f) | |
| return True | |
def query(self, question: str, session_id: str = None):
    """
    Answer a question from the retrieved context and report the sources used.

    Results are filtered to the user's session documents + global samples.
    The session ID is stored on the pipeline before the chain is invoked so
    that retrieval can read it.

    Args:
        question: User's question string
        session_id: User's session ID for filtering results
    Returns:
        dict: {
            "answer": str,
            "citations": List[dict],
            "num_sources": int
        }
        "citations" is produced by `_extract_citations` from the documents
        the chain retrieved; "num_sources" is the number of those documents.
    Raises:
        ValueError: If rate limit (10 queries/hour) is exceeded
    """
    # Check rate limit
    if not self._check_rate_limit():
        raise ValueError(
            "Rate limit exceeded. You can only ask 10 questions per hour. "
            "Please try again later."
        )

    # Set session ID for filtered retrieval
    self._current_session_id = session_id
    chain_output = self.rag_chain.invoke(question)
    result = chain_output["result"]

    # Extract answer text from whichever attribute the LLM output exposes
    if hasattr(result, "content"):
        answer_text = result.content
    elif hasattr(result, "text"):
        answer_text = result.text
    else:
        answer_text = str(result)

    # Fall back to an apology when the model produced nothing
    if not answer_text or answer_text.strip() == "":
        answer_text = "I apologize, but I couldn't generate a response. Please try rephrasing your question."

    # Report the retrieved documents as citations, as the return contract states
    source_documents = chain_output.get("source_documents") or []
    return {
        "answer": answer_text,
        "citations": self._extract_citations(source_documents),
        "num_sources": len(source_documents),
    }
def query_stream(self, question: str, session_id: str = None):
    """
    Stream the answer for real-time display.

    Retrieval happens once up front; the LLM output is then streamed. When a
    session ID is given, the session restriction is sent to the vector store
    as a metadata filter, so the top-4 results are chosen among documents the
    session may see (its own uploads plus samples) rather than being filtered
    after other sessions' chunks have already used up the limit.

    User-facing conditions (rate limit reached, nothing retrieved, empty
    model output) are yielded as a single message instead of raising.

    Args:
        question: User's question string
        session_id: User's session ID for filtering results
    Yields:
        str: Accumulated answer text (each yield contains full answer so far)
    """
    # Check rate limit
    if not self._check_rate_limit():
        yield "⚠️ Rate limit exceeded. You can only ask 10 questions per hour. Please try again later."
        return

    # Set session ID for filtered retrieval
    self._current_session_id = session_id

    # Get documents using retriever (non-streaming part)
    search_kwargs = {"k": 4}
    if session_id:
        # Session's own documents OR sample documents (is_sample=True)
        search_kwargs["filter"] = {
            "$or": [{"session_id": session_id}, {"is_sample": True}]
        }
    retriever = self.vector_store.as_retriever(search_kwargs=search_kwargs)
    docs = retriever.invoke(question)

    if not docs:
        yield "I couldn't find relevant information in your documents. Please try rephrasing your question."
        return

    # Build context and sources (de-duplicated, in retrieval order)
    context = "\n\n".join(d.page_content for d in docs)
    sources = ", ".join(
        dict.fromkeys(d.metadata.get("source", "").split("/")[-1] for d in docs)
    )

    # Format prompt
    prompt = self._format_prompt(context, sources, question)

    # Stream from LLM, yielding the text accumulated so far after each chunk
    full_answer = ""
    for chunk in self.llm.stream(prompt):
        if hasattr(chunk, "content"):
            full_answer += chunk.content
        else:
            full_answer += str(chunk)
        yield full_answer

    # Same fallback as `query` when the model produced no text at all
    if not full_answer.strip():
        yield "I apologize, but I couldn't generate a response. Please try rephrasing your question."
| def _format_prompt(self, context: str, sources: str, question: str) -> str: | |
| """ | |
| Format the RAG prompt with context, sources, and question. | |
| Args: | |
| context: Retrieved document content | |
| sources: Comma-separated source filenames | |
| question: User's question | |
| Returns: | |
| str: Formatted prompt string | |
| """ | |
| return f"""You are an expert AI assistant specializing in document analysis. Your goal is to provide comprehensive, accurate, and well-cited answers. | |
| Available Documents: {sources} | |
| Context from Documents: | |
| {context} | |
| User Question: {question} | |
| INSTRUCTIONS FOR YOUR RESPONSE: | |
| 1. **Analyze Thoroughly**: Read the context carefully and identify all relevant information | |
| 2. **Answer Comprehensively**: Provide a complete, detailed answer that fully addresses the question | |
| 3. **Use Proper Structure**: | |
| - Start with a clear, direct answer | |
| - Follow with supporting details and explanation | |
| - Use markdown formatting (headings, bullet points, bold) for readability | |
| 4. **Cite Sources Inline**: As you make specific claims, cite the source immediately | |
| - Format: (Source: filename, Page X) or (Source: filename) if page unknown | |
| - Example: "The termination period is 30 days (Source: service_agreement.pdf, Page 3)" | |
| - Be specific about which document and page number whenever possible | |
| 5. **Include a Sources Section**: At the end of your answer, add: | |
| **Sources Referenced:** | |
| • filename (Page X) - Brief note about what info came from here | |
| • filename2 (Page Y) - Brief note | |
| 6. **Quality Standards**: | |
| - Be specific and precise with facts, numbers, dates, and terms | |
| - Quote exact phrases when important (use quotation marks) | |
| - If information is unclear or missing, state what's uncertain | |
| - Connect related points to create a cohesive narrative | |
| Answer:""" | |
| def _extract_citations(self, source_documents: List[Document]) -> List[dict]: | |
| """ | |
| Extract formatted citations from source documents with page numbers and previews. | |
| Args: | |
| source_documents: List of retrieved Document objects from RAG chain | |
| Returns: | |
| List[dict]: Formatted citations with id, source, page, and preview | |
| """ | |
| import re | |
| citations = [] | |
| for idx, doc in enumerate(source_documents, 1): | |
| # Extract file name (basename only) | |
| source_path = doc.metadata.get("source", "Unknown") | |
| file_name = ( | |
| source_path.split("/")[-1] if "/" in source_path else source_path | |
| ) | |
| # Parse page number from content (PDF format: "---- Page X ----") | |
| page_num = None | |
| content = doc.page_content | |
| # Try direct metadata first | |
| if "page" in doc.metadata: | |
| page_num = str(doc.metadata["page"]) | |
| # Fallback: parse from content markers | |
| elif "---- Page " in content: | |
| match = re.search(r"---- Page (\d+) ----", content) | |
| if match: | |
| page_num = match.group(1) | |
| # Get clean preview (remove page markers) | |
| preview = re.sub(r"---- Page \d+ ----", "", content).strip() | |
| # Take first 150 chars for preview | |
| if len(preview) > 150: | |
| preview = preview[:150] + "..." | |
| citations.append( | |
| { | |
| "id": idx, | |
| "source": file_name, | |
| "page": page_num, | |
| "preview": preview, | |
| "full_content": content, | |
| } | |
| ) | |
| return citations | |
| def _track_document(self, source_path: str, session_id: str = None) -> None: | |
| """ | |
| Track document upload timestamp for auto-cleanup. | |
| Args: | |
| source_path: Path to the uploaded document | |
| session_id: User's session ID for the document | |
| """ | |
| # Load existing metadata | |
| if self.doc_metadata_file.exists(): | |
| with open(self.doc_metadata_file, "r") as f: | |
| metadata = json.load(f) | |
| else: | |
| metadata = {"documents": {}} | |
| # Add new document with current timestamp and session | |
| metadata["documents"][source_path] = { | |
| "uploaded_at": datetime.now().isoformat(), | |
| "session_id": session_id, | |
| "is_sample": False, | |
| } | |
| # Save updated metadata | |
| with open(self.doc_metadata_file, "w") as f: | |
| json.dump(metadata, f, indent=2) | |
| def _cleanup_old_documents(self) -> None: | |
| """ | |
| Remove documents older than 7 days from vector store. | |
| Sample documents are never deleted. | |
| """ | |
| if not self.doc_metadata_file.exists(): | |
| return | |
| with open(self.doc_metadata_file, "r") as f: | |
| metadata = json.load(f) | |
| now = datetime.now() | |
| seven_days_ago = now - timedelta(days=7) | |
| documents_to_keep = {} | |
| deleted_count = 0 | |
| for doc_path, doc_info in metadata.get("documents", {}).items(): | |
| upload_time = datetime.fromisoformat(doc_info["uploaded_at"]) | |
| # Keep if uploaded within 7 days OR is a sample | |
| if upload_time > seven_days_ago or doc_info.get("is_sample", False): | |
| documents_to_keep[doc_path] = doc_info | |
| else: | |
| # Actually delete from ChromaDB using source path filter | |
| try: | |
| self.vector_store._collection.delete(where={"source": doc_path}) | |
| deleted_count += 1 | |
| print(f"Deleted expired document: {doc_path}") | |
| except Exception as e: | |
| print(f"Error deleting document {doc_path}: {e}") | |
| # Update metadata file | |
| metadata["documents"] = documents_to_keep | |
| with open(self.doc_metadata_file, "w") as f: | |
| json.dump(metadata, f, indent=2) | |
| if deleted_count > 0: | |
| print(f"Cleanup complete: removed {deleted_count} expired documents") | |
def get_documents_by_session(self, session_id: str) -> List[dict]:
    """
    List the tracked documents that belong to a given session.

    Reads `self.doc_metadata_file`. A missing or unreadable file yields an
    empty list.

    Args:
        session_id: User's session ID
    Returns:
        List[dict]: One dict per document of this session with keys
            "filename" (base name of the path), "path" (the tracked source
            path) and "uploaded_at" (stored ISO timestamp, or None if the
            entry has none)
    """
    if not self.doc_metadata_file.exists():
        return []

    try:
        with open(self.doc_metadata_file, "r") as f:
            metadata = json.load(f)
    except ValueError as e:  # includes json.JSONDecodeError
        print(f"Could not parse {self.doc_metadata_file}: {e}")
        return []

    tracked = metadata.get("documents", {}) if isinstance(metadata, dict) else {}
    if not isinstance(tracked, dict):
        return []

    return [
        {
            # Extract just the filename
            "filename": doc_path.split("/")[-1],
            "path": doc_path,
            "uploaded_at": doc_info.get("uploaded_at"),
        }
        for doc_path, doc_info in tracked.items()
        if isinstance(doc_info, dict) and doc_info.get("session_id") == session_id
    ]
def delete_document(self, session_id: str, source_path: str) -> bool:
    """
    Delete one of a session's documents from the vector store and metadata.

    The tracked entry for `source_path` must exist and belong to
    `session_id`. When a session ID is given, the vector-store delete is
    restricted to chunks carrying both this source path and this session ID,
    so chunks that another session stored under the same source path are
    left alone. The metadata entry is removed only after the vector-store
    delete succeeded.

    Args:
        session_id: User's session ID (for verification)
        source_path: Full path to the document to delete
    Returns:
        bool: True if deleted, False if not found, not authorized, the
            metadata file could not be read, or the delete failed
    """
    if not self.doc_metadata_file.exists():
        return False

    try:
        with open(self.doc_metadata_file, "r") as f:
            metadata = json.load(f)
    except ValueError as e:  # includes json.JSONDecodeError
        print(f"Could not parse {self.doc_metadata_file}: {e}")
        return False

    tracked = metadata.get("documents") if isinstance(metadata, dict) else None
    if not isinstance(tracked, dict):
        return False

    # Verify document belongs to this session
    doc_info = tracked.get(source_path)
    if not doc_info:
        return False
    if doc_info.get("session_id") != session_id:
        return False  # Not authorized to delete

    # Scope the delete to this session's chunks whenever a session is known
    where = {"source": source_path}
    if session_id:
        where = {"$and": [{"source": source_path}, {"session_id": session_id}]}

    # Delete from ChromaDB
    try:
        self.vector_store._collection.delete(where=where)
    except Exception as e:
        print(f"Error deleting from ChromaDB: {e}")
        return False

    # Remove from metadata
    del tracked[source_path]
    with open(self.doc_metadata_file, "w") as f:
        json.dump(metadata, f, indent=2)
    return True