ai-rag-document / app /rag_pipeline.py
pkgprateek's picture
feat: multi-document upload + streaming LLM responses
643f470
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.documents import Document
from langchain_core.runnables import (
RunnableParallel,
RunnablePassthrough,
RunnableLambda,
)
from typing import List
import os
from datetime import datetime, timedelta
import json
from pathlib import Path
# Fix tokenizer warning: force the TOKENIZERS_PARALLELISM environment variable
# to "false" for this process. This runs at import time, before any embedding
# model is constructed by RAGPipeline below.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class RAGPipeline:
    """
    Retrieval-augmented generation pipeline backed by a persistent Chroma store.

    Responsibilities visible in this class:
      * embed and store document chunks, tagged with a session id and timestamp
      * retrieve chunks for a question, restricted to the caller's session plus
        global sample documents
      * generate answers (blocking or streaming) through an OpenAI-compatible
        chat endpoint (Groq or OpenRouter)
      * enforce a file-based rate limit and a retention period for uploads

    NOTE(review): the active session id is stored on the instance
    (``_current_session_id``) and read by the chain at call time, so a single
    instance is not safe for concurrent queries from different sessions.
    """

    # Model configuration for multi-provider support
    MODEL_CONFIG = {
        "gpt-oss-120b": {
            "provider": "groq",
            "model": "openai/gpt-oss-120b",
            "display": "GPT-OSS 120B (OpenAI)",
            "temperature": 0.1,
            "max_tokens": 1024,
        },
        "llama-3.3-70b": {
            "provider": "groq",
            "model": "llama-3.3-70b-versatile",
            "display": "Llama 3.3 70B (Meta)",
            "temperature": 0.1,
            "max_tokens": 1024,
        },
        "gemma-3-27b": {
            "provider": "openrouter",
            "model": "google/gemma-3-27b-it:free",
            "display": "Gemma 3 27B (Google)",
            "temperature": 0.1,
            "max_tokens": 512,
        },
    }

    # Per-provider connection settings: API-key environment variable,
    # OpenAI-compatible base URL, and the page where a key can be obtained.
    PROVIDER_CONFIG = {
        "groq": {
            "env_var": "GROQ_API_KEY",
            "base_url": "https://api.groq.com/openai/v1",
            "key_url": "https://console.groq.com/keys",
        },
        "openrouter": {
            "env_var": "OPENROUTER_API_KEY",
            "base_url": "https://openrouter.ai/api/v1",
            "key_url": "https://openrouter.ai/keys",
        },
    }

    # Number of chunks retrieved per question
    RETRIEVAL_K = 4
    # Rate limiting: at most this many queries inside the sliding window
    RATE_LIMIT_MAX_QUERIES = 10
    RATE_LIMIT_WINDOW = timedelta(hours=1)
    # Uploaded (non-sample) documents older than this are removed on start-up
    RETENTION_PERIOD = timedelta(days=7)

    # Single source of truth for the prompt; used by both the LCEL chain and
    # the streaming path so the two cannot drift apart.
    PROMPT_TEMPLATE = """You are an expert AI assistant specializing in document analysis. Your goal is to provide comprehensive, accurate, and well-cited answers.
Available Documents: {sources}
Context from Documents:
{context}
User Question: {question}
INSTRUCTIONS FOR YOUR RESPONSE:
1. **Analyze Thoroughly**: Read the context carefully and identify all relevant information
2. **Answer Comprehensively**: Provide a complete, detailed answer that fully addresses the question
3. **Use Proper Structure**:
- Start with a clear, direct answer
- Follow with supporting details and explanation
- Use markdown formatting (headings, bullet points, bold) for readability
4. **Cite Sources Inline**: As you make specific claims, cite the source immediately
- Format: (Source: filename, Page X) or (Source: filename) if page unknown
- Example: "The termination period is 30 days (Source: service_agreement.pdf, Page 3)"
- Be specific about which document and page number whenever possible
5. **Include a Sources Section**: At the end of your answer, add:
**Sources Referenced:**
• filename (Page X) - Brief note about what info came from here
• filename2 (Page Y) - Brief note
6. **Quality Standards**:
- Be specific and precise with facts, numbers, dates, and terms
- Quote exact phrases when important (use quotation marks)
- If information is unclear or missing, state what's uncertain
- Connect related points to create a cohesive narrative
Answer:"""

    def __init__(
        self,
        persist_directory: str = "./data/chroma_db",
        default_model: str = "gpt-oss-120b",
    ):
        """
        Initialize RAG pipeline with embeddings, vector store, and multi-provider LLM support.

        Also prepares the rate-limit and document-tracking files under ./data
        and removes expired uploads before the first query.

        Args:
            persist_directory: Path to store ChromaDB vector database (default: ./data/chroma_db)
            default_model: Model key from MODEL_CONFIG (default: gpt-oss-120b)

        Raises:
            ValueError: If default_model is unknown or its API key is missing
        """
        # Initialize embeddings (BAAI/bge-small-en-v1.5)
        self.embeddings = HuggingFaceEmbeddings(
            model_name="BAAI/bge-small-en-v1.5",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},  # Important for bge models
        )

        # Initialize vector store
        self.vector_store = Chroma(
            persist_directory=persist_directory,
            embedding_function=self.embeddings,
        )

        # Rate limiting state (query timestamps)
        self.rate_limit_file = Path("./data/rate_limit.json")
        self.rate_limit_file.parent.mkdir(parents=True, exist_ok=True)

        # Document tracking for auto-cleanup
        self.doc_metadata_file = Path("./data/document_metadata.json")
        self.doc_metadata_file.parent.mkdir(parents=True, exist_ok=True)

        # Auto-cleanup on initialization
        self._cleanup_old_documents()

        # Initialize LLM with default model
        self.current_model = default_model
        self.llm = self._initialize_llm(default_model)

        # Current session ID for retrieval filtering (set per-query)
        self._current_session_id = None

        # Create RAG chain
        self.rag_chain = self.create_rag_chain()

    def _initialize_llm(self, model_key: str):
        """
        Initialize LLM based on provider and model configuration.

        Args:
            model_key: Key from MODEL_CONFIG dictionary

        Returns:
            ChatOpenAI: Configured LLM instance pointed at the provider's
            OpenAI-compatible endpoint

        Raises:
            ValueError: If model_key is invalid, the provider is unknown,
                or the required API key environment variable is not set
        """
        if model_key not in self.MODEL_CONFIG:
            raise ValueError(
                f"Invalid model key: {model_key}. "
                f"Available models: {', '.join(self.MODEL_CONFIG.keys())}"
            )

        config = self.MODEL_CONFIG[model_key]
        provider = config["provider"]

        settings = self.PROVIDER_CONFIG.get(provider)
        if settings is None:
            raise ValueError(f"Unknown provider: {provider}")

        api_key = os.getenv(settings["env_var"])
        if not api_key:
            raise ValueError(
                f"{settings['env_var']} environment variable not set. "
                f"Get one free at {settings['key_url']}"
            )

        return ChatOpenAI(
            model=config["model"],
            openai_api_key=api_key,
            openai_api_base=settings["base_url"],
            temperature=config["temperature"],
            max_tokens=config["max_tokens"],
        )

    def switch_model(self, model_key: str) -> str:
        """
        Dynamically switch to a different LLM model and recreate the RAG chain.

        The current model is left untouched if the new one cannot be initialized.

        Args:
            model_key: Key from MODEL_CONFIG dictionary

        Returns:
            str: Display name of the switched model

        Raises:
            ValueError: If model_key is invalid or API key is missing
        """
        # Build the new LLM first so a failure does not leave partial state
        self.llm = self._initialize_llm(model_key)
        self.current_model = model_key

        # Recreate RAG chain with new LLM
        self.rag_chain = self.create_rag_chain()
        return self.MODEL_CONFIG[model_key]["display"]

    @staticmethod
    def _basename(source_path) -> str:
        """Return the part of a '/'-separated path after the last slash."""
        return str(source_path).split("/")[-1]

    @staticmethod
    def _join_context(docs) -> str:
        """Concatenate chunk texts into the prompt's context section."""
        return "\n\n".join(d.page_content for d in docs)

    @classmethod
    def _join_sources(cls, docs) -> str:
        """Return the unique source filenames of docs as a comma-separated string."""
        # Sorted so the prompt is deterministic for the same set of chunks
        names = {cls._basename(d.metadata.get("source", "")) for d in docs}
        return ", ".join(sorted(names))

    @staticmethod
    def _filter_by_session(docs, session_id):
        """Keep docs that belong to session_id or are global samples."""
        if not session_id:
            return docs
        return [
            d
            for d in docs
            if d.metadata.get("session_id") == session_id
            or d.metadata.get("is_sample", False)
        ]

    def _retrieve(self, question: str, session_id: str = None):
        """
        Retrieve the top RETRIEVAL_K chunks visible to a session.

        The session restriction is passed to the vector store as a metadata
        filter so that chunks from other sessions cannot use up the k slots;
        the same rule is re-applied in Python as a safety net.

        Args:
            question: Query text to embed and search for
            session_id: Session to restrict to; falsy means no restriction

        Returns:
            list: Retrieved Document objects
        """
        search_kwargs = {"k": self.RETRIEVAL_K}
        if session_id:
            search_kwargs["filter"] = {
                "$or": [{"session_id": session_id}, {"is_sample": True}]
            }
        docs = self.vector_store.similarity_search(question, **search_kwargs)
        return self._filter_by_session(docs, session_id)

    def create_rag_chain(self):
        """
        Creates the RAG chain by combining retrieval, prompt template, and LLM.

        Retrieval runs once per invocation; the same documents feed the prompt
        context, the source list, and the returned ``source_documents``.

        Returns:
            Runnable: Chain that maps a question string to
            ``{"result": <LLM output>, "source_documents": [Document, ...]}``
        """
        prompt = PromptTemplate(
            input_variables=["context", "sources", "question"],
            template=self.PROMPT_TEMPLATE,
        )

        def retrieve(question):
            # The session id is read at call time, not at chain-build time
            return self._retrieve(question, self._current_session_id)

        # Step 1: retrieve once, carrying the question alongside the documents
        retrieve_once = RunnableParallel(
            {
                "docs": RunnableLambda(retrieve),
                "question": RunnablePassthrough(),
            }
        )

        # Step 2: build the prompt inputs from that state and call the LLM
        answer_chain = (
            {
                "context": RunnableLambda(
                    lambda state: self._join_context(state["docs"])
                ),
                "sources": RunnableLambda(
                    lambda state: self._join_sources(state["docs"])
                ),
                "question": RunnableLambda(lambda state: state["question"]),
            }
            | prompt
            | self.llm
        )

        rag_chain = retrieve_once | RunnableParallel(
            {
                "result": answer_chain,
                "source_documents": RunnableLambda(lambda state: state["docs"]),
            }
        )
        return rag_chain

    def add_documents(
        self,
        documents: List["Document"],
        session_id: str = None,
        is_sample: bool = False,
    ) -> None:
        """
        Add processed document chunks to the vector store for retrieval.

        Adds session_id and timestamp metadata for isolation and auto-cleanup.
        Every distinct source in the batch is tracked, so multi-document
        uploads can be listed, deleted, and expired individually.

        Args:
            documents: List of Document objects with text and metadata
            session_id: User's session ID for isolation (None for samples)
            is_sample: If True, document is global and won't be auto-deleted
        """
        if not documents:
            return

        # Add session and timestamp metadata to each chunk
        now = datetime.now().isoformat()
        for doc in documents:
            doc.metadata["session_id"] = session_id if not is_sample else "global"
            doc.metadata["uploaded_at"] = now
            doc.metadata["is_sample"] = is_sample

        self.vector_store.add_documents(documents)

        # Track document metadata for cleanup (skip samples)
        if not is_sample:
            # dict.fromkeys de-duplicates while keeping first-seen order
            sources = dict.fromkeys(
                doc.metadata.get("source", "unknown") for doc in documents
            )
            for source in sources:
                self._track_document(source, session_id=session_id)

    def _load_recent_query_times(self) -> list:
        """Read stored query timestamps; a missing or unreadable file yields []."""
        try:
            with open(self.rate_limit_file, "r", encoding="utf-8") as f:
                content = f.read().strip()
            if not content:
                return []
            data = json.loads(content)
            return [datetime.fromisoformat(q) for q in data.get("queries", [])]
        except FileNotFoundError:
            return []
        except (ValueError, TypeError, AttributeError):
            # Corrupted or unexpected content: start fresh
            return []

    def _check_rate_limit(self) -> bool:
        """
        Enforces the query rate limit by tracking query timestamps on disk.

        When the call is allowed, the current time is recorded as a query.

        Returns:
            bool: True if within limit, False if exceeded
        """
        now = datetime.now()
        window_start = now - self.RATE_LIMIT_WINDOW

        # Drop queries that fell out of the sliding window
        recent_queries = [
            q for q in self._load_recent_query_times() if q > window_start
        ]

        if len(recent_queries) >= self.RATE_LIMIT_MAX_QUERIES:
            return False

        recent_queries.append(now)
        with open(self.rate_limit_file, "w", encoding="utf-8") as f:
            json.dump({"queries": [q.isoformat() for q in recent_queries]}, f)
        return True

    def query(self, question: str, session_id: str = None):
        """
        Query the RAG system with a question, retrieves relevant context and generates answer.

        Results are filtered to the user's session documents + global samples.

        Args:
            question: User's question string
            session_id: User's session ID for filtering results

        Returns:
            dict: {
                "answer": str,
                "citations": List[dict],  # one entry per retrieved chunk
                "num_sources": int        # number of retrieved chunks cited
            }

        Raises:
            ValueError: If rate limit (10 queries/hour) is exceeded
        """
        # Check rate limit
        if not self._check_rate_limit():
            raise ValueError(
                "Rate limit exceeded. You can only ask 10 questions per hour. "
                "Please try again later."
            )

        # Set session ID for filtered retrieval
        self._current_session_id = session_id
        output = self.rag_chain.invoke(question)
        result = output["result"]

        # Extract answer text
        if hasattr(result, "content"):
            answer_text = result.content
        elif hasattr(result, "text"):
            text = result.text
            answer_text = text() if callable(text) else text
        else:
            answer_text = str(result)

        if answer_text is not None and not isinstance(answer_text, str):
            answer_text = str(answer_text)

        # Check if answer is empty
        if not answer_text or answer_text.strip() == "":
            answer_text = "I apologize, but I couldn't generate a response. Please try rephrasing your question."

        citations = self._extract_citations(output.get("source_documents") or [])
        return {
            "answer": answer_text,
            "citations": citations,
            "num_sources": len(citations),
        }

    def query_stream(self, question: str, session_id: str = None):
        """
        Stream answer tokens for real-time display.

        Args:
            question: User's question string
            session_id: User's session ID for filtering results

        Yields:
            str: Accumulated answer text (each yield contains full answer so far).
            A single message is yielded instead when the rate limit is hit or
            no documents are retrieved.
        """
        # Check rate limit
        if not self._check_rate_limit():
            yield "⚠️ Rate limit exceeded. You can only ask 10 questions per hour. Please try again later."
            return

        # Set session ID for filtered retrieval
        self._current_session_id = session_id

        # Retrieval is the non-streaming part
        docs = self._retrieve(question, session_id)
        if not docs:
            yield "I couldn't find relevant information in your documents. Please try rephrasing your question."
            return

        prompt = self._format_prompt(
            self._join_context(docs), self._join_sources(docs), question
        )

        # Stream from LLM
        full_answer = ""
        for chunk in self.llm.stream(prompt):
            piece = chunk.content if hasattr(chunk, "content") else str(chunk)
            if not isinstance(piece, str):
                piece = str(piece)
            full_answer += piece
            yield full_answer

    def _format_prompt(self, context: str, sources: str, question: str) -> str:
        """
        Format the RAG prompt with context, sources, and question.

        Args:
            context: Retrieved document content
            sources: Comma-separated source filenames
            question: User's question

        Returns:
            str: Formatted prompt string
        """
        return self.PROMPT_TEMPLATE.format(
            context=context, sources=sources, question=question
        )

    def _extract_citations(self, source_documents: List["Document"]) -> List[dict]:
        """
        Extract formatted citations from source documents with page numbers and previews.

        Args:
            source_documents: List of retrieved Document objects from RAG chain

        Returns:
            List[dict]: One dict per document with keys id (1-based), source
            (filename), page (str or None), preview (first 150 chars of the
            text without page markers), and full_content
        """
        import re

        # Page markers embedded in the chunk text: "---- Page X ----"
        page_marker = re.compile(r"---- Page (\d+) ----")

        citations = []
        for idx, doc in enumerate(source_documents, 1):
            file_name = self._basename(doc.metadata.get("source", "Unknown"))
            content = doc.page_content

            # Prefer the page from metadata; fall back to a marker in the text
            page_num = None
            if "page" in doc.metadata:
                page_num = str(doc.metadata["page"])
            else:
                match = page_marker.search(content)
                if match:
                    page_num = match.group(1)

            # Clean preview: strip page markers, cap at 150 characters
            preview = page_marker.sub("", content).strip()
            if len(preview) > 150:
                preview = preview[:150] + "..."

            citations.append(
                {
                    "id": idx,
                    "source": file_name,
                    "page": page_num,
                    "preview": preview,
                    "full_content": content,
                }
            )
        return citations

    def _load_doc_metadata(self) -> dict:
        """
        Load the document-tracking file.

        Returns:
            dict: Parsed metadata that always contains a "documents" dict.
            A missing, empty, or corrupted file yields an empty structure
            instead of raising.
        """
        try:
            with open(self.doc_metadata_file, "r", encoding="utf-8") as f:
                content = f.read().strip()
            metadata = json.loads(content) if content else {}
        except FileNotFoundError:
            metadata = {}
        except ValueError as e:  # includes json.JSONDecodeError
            print(f"Ignoring unreadable document metadata file: {e}")
            metadata = {}

        if not isinstance(metadata, dict):
            metadata = {}
        if not isinstance(metadata.get("documents"), dict):
            metadata["documents"] = {}
        return metadata

    def _save_doc_metadata(self, metadata: dict) -> None:
        """Write the document-tracking structure back to disk."""
        with open(self.doc_metadata_file, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

    def _track_document(self, source_path: str, session_id: str = None) -> None:
        """
        Track document upload timestamp for auto-cleanup.

        Args:
            source_path: Path to the uploaded document
            session_id: User's session ID for the document
        """
        metadata = self._load_doc_metadata()

        # Add (or refresh) the entry with the current timestamp and session
        metadata["documents"][source_path] = {
            "uploaded_at": datetime.now().isoformat(),
            "session_id": session_id,
            "is_sample": False,
        }
        self._save_doc_metadata(metadata)

    @staticmethod
    def _is_expired(doc_info, cutoff: datetime) -> bool:
        """Return True when a tracked entry is a non-sample uploaded at or before cutoff."""
        if not isinstance(doc_info, dict) or doc_info.get("is_sample", False):
            return False
        try:
            return datetime.fromisoformat(doc_info["uploaded_at"]) <= cutoff
        except (KeyError, TypeError, ValueError):
            # No usable timestamp: never delete data on a guess
            return False

    def _cleanup_old_documents(self) -> None:
        """
        Remove documents older than the retention period from vector store.

        Sample documents are never deleted. An entry whose vector-store
        deletion fails stays tracked so the next cleanup retries it.
        """
        if not self.doc_metadata_file.exists():
            return

        metadata = self._load_doc_metadata()
        cutoff = datetime.now() - self.RETENTION_PERIOD

        documents_to_keep = {}
        deleted_count = 0
        for doc_path, doc_info in metadata["documents"].items():
            if not self._is_expired(doc_info, cutoff):
                documents_to_keep[doc_path] = doc_info
                continue

            # Delete from ChromaDB using source path filter
            try:
                self.vector_store._collection.delete(where={"source": doc_path})
            except Exception as e:
                print(f"Error deleting document {doc_path}: {e}")
                documents_to_keep[doc_path] = doc_info
                continue

            deleted_count += 1
            print(f"Deleted expired document: {doc_path}")

        # Only rewrite the file when something was actually removed
        if deleted_count > 0:
            metadata["documents"] = documents_to_keep
            self._save_doc_metadata(metadata)
            print(f"Cleanup complete: removed {deleted_count} expired documents")

    def get_documents_by_session(self, session_id: str) -> List[dict]:
        """
        Get the tracked documents for a given session.

        Args:
            session_id: User's session ID

        Returns:
            List[dict]: One dict per document with keys filename, path,
            and uploaded_at
        """
        documents = []
        for doc_path, doc_info in self._load_doc_metadata()["documents"].items():
            if not isinstance(doc_info, dict):
                continue
            if doc_info.get("session_id") != session_id:
                continue
            documents.append(
                {
                    "filename": self._basename(doc_path),
                    "path": doc_path,
                    "uploaded_at": doc_info.get("uploaded_at"),
                }
            )
        return documents

    def delete_document(self, session_id: str, source_path: str) -> bool:
        """
        Delete a specific document from vector store and metadata.

        Args:
            session_id: User's session ID (for verification)
            source_path: Full path to the document to delete

        Returns:
            bool: True if deleted, False if not found, not authorized,
            or the vector-store deletion failed
        """
        metadata = self._load_doc_metadata()

        # Verify document is tracked and belongs to this session
        doc_info = metadata["documents"].get(source_path)
        if not isinstance(doc_info, dict) or not doc_info:
            return False
        if doc_info.get("session_id") != session_id:
            return False  # Not authorized to delete

        # Delete from ChromaDB
        try:
            self.vector_store._collection.delete(where={"source": source_path})
        except Exception as e:
            print(f"Error deleting from ChromaDB: {e}")
            return False

        # Remove from metadata
        del metadata["documents"][source_path]
        self._save_doc_metadata(metadata)
        return True