DocuMind_hf / main.py
ktejeshnaidu's picture
Update main.py
132a5cc verified
import asyncio
import json
import os
import subprocess
import sys

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi import Request
from fastapi.responses import StreamingResponse
from langchain_chroma import Chroma
from langchain_core.prompts import PromptTemplate
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from pydantic import BaseModel
from sentence_transformers import CrossEncoder
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
# Load environment variables from a local .env file, if one exists
# (startup_event later checks GROQ_API_KEY in the environment).
load_dotenv()

app = FastAPI(title="DocuMind Enterprise RAG API")

# Setup Rate Limiter
# Requests are bucketed by the client's remote address; per-route limits are
# declared with @limiter.limit(...) on the individual endpoints.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware)

# Directory of the persisted Chroma vector store (relative to the working directory).
CHROMA_DB_DIR = "vectorstore"
# Model identifiers passed to HuggingFaceEmbeddings, CrossEncoder and ChatGroq
# respectively in startup_event.
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
LLM_MODEL = "llama-3.1-8b-instant"

# Shared components, populated by startup_event(). They stay None until then;
# vectorstore/base_retriever also stay None while CHROMA_DB_DIR does not exist,
# and llm stays None while GROQ_API_KEY is unset.
embeddings = None
vectorstore = None
base_retriever = None
cross_encoder = None
llm = None
@app.on_event("startup")
def startup_event():
    """Load, or refresh, the shared retrieval and generation components.

    Runs once at application startup and is also called by the /ingest
    endpoint to pick up newly ingested documents. The embedding model, the
    cross-encoder and the LLM client are created only on the first call and
    reused afterwards; only the Chroma handle and its retriever are rebuilt
    on every call, since they are what reflects the on-disk store.

    Side effects:
        Rebinds the module globals ``embeddings``, ``vectorstore``,
        ``base_retriever``, ``cross_encoder`` and ``llm``.
        ``vectorstore``/``base_retriever`` are left untouched when
        ``CHROMA_DB_DIR`` does not exist, and ``llm`` is left untouched when
        ``GROQ_API_KEY`` is not set.
    """
    global embeddings, vectorstore, base_retriever, cross_encoder, llm

    print("Loading vector store & embedding model...")
    # Model weights do not change between calls, so avoid reloading them on
    # every /ingest-triggered refresh.
    if embeddings is None:
        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

    if os.path.exists(CHROMA_DB_DIR):
        # Always reopen the store so a refresh sees freshly ingested data.
        vectorstore = Chroma(persist_directory=CHROMA_DB_DIR, embedding_function=embeddings)
        # Fetch 5 candidates per query; /query re-ranks them afterwards.
        base_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

    print("Initializing CrossEncoder ReRanker...")
    if cross_encoder is None:
        cross_encoder = CrossEncoder(RERANKER_MODEL)

    print("Initializing LLM via Groq...")
    if not os.environ.get("GROQ_API_KEY"):
        print("WARNING: GROQ_API_KEY not found in environment!")
    elif llm is None:
        llm = ChatGroq(model_name=LLM_MODEL, temperature=0, streaming=True)
# JSON request body accepted by POST /query.
# (Documented with comments rather than a docstring so the generated
# OpenAPI schema stays exactly as it was.)
class QueryRequest(BaseModel):
    # The user's question; used for retrieval, re-ranking and the LLM prompt.
    question: str
# Prompt sent to the LLM by query_documents. It instructs the model to answer
# only from the supplied context; `{context}` and `{question}` are filled in
# per request.
prompt_template = PromptTemplate.from_template("""You are a factual assistant for DocuMind. Answer ONLY using the context below.
If the answer isn't in the context, say "I don't know."
Context: {context}
Question: {question}""")
@app.post("/query")
@limiter.limit("5/minute")
async def query_documents(request: Request, req: QueryRequest):
    """Answer a question from the ingested documents as an NDJSON stream.

    The response is a stream of newline-delimited JSON objects: first one
    ``{"type": "sources", "data": [...]}`` record listing the passages used,
    then one ``{"type": "token", "content": "..."}`` record per generated
    chunk. Responds with HTTP 500 if the backend is not fully initialized.
    Limited to 5 requests per minute per client address.
    """
    # Snapshot the shared components so a concurrent reload (triggered by
    # /ingest) cannot swap them out halfway through this request.
    retriever, reranker, model = base_retriever, cross_encoder, llm
    if retriever is None or reranker is None or model is None:
        raise HTTPException(status_code=500, detail="Backend not fully initialized (Vectorstore or LLM missing).")

    # Retrieval is a synchronous call; run it in a worker thread so it does
    # not stall the event loop for other requests.
    initial_docs = await asyncio.to_thread(retriever.invoke, req.question)
    if not initial_docs:
        # Stream "I don't know." with empty sources
        async def empty_response():
            yield json.dumps({"type": "sources", "data": []}) + "\n"
            yield json.dumps({"type": "token", "content": "I don't know."}) + "\n"

        return StreamingResponse(empty_response(), media_type="application/x-ndjson")

    # Score every (question, passage) pair with the cross-encoder, also off
    # the event loop because model inference is synchronous.
    pairs = [[req.question, doc.page_content] for doc in initial_docs]
    scores = await asyncio.to_thread(reranker.predict, pairs)
    for doc, score in zip(initial_docs, scores):
        # float() makes the score JSON-serializable.
        doc.metadata["relevance_score"] = float(score)

    # Keep the three highest-scoring passages as the LLM context.
    ranked_docs = sorted(initial_docs, key=lambda d: d.metadata["relevance_score"], reverse=True)
    top_docs = ranked_docs[:3]
    context_text = "\n\n".join(doc.page_content for doc in top_docs)
    chain = prompt_template | model

    async def generate_response():
        sources = [
            {
                "source": d.metadata.get("source", "Unknown"),
                "score": d.metadata.get("relevance_score"),
                "content": d.page_content,
            }
            for d in top_docs
        ]
        # Emit sources first
        yield json.dumps({"type": "sources", "data": sources}) + "\n"
        # Emit tokens
        async for chunk in chain.astream({"context": context_text, "question": req.question}):
            # Skip chunks with empty content.
            if chunk.content:
                yield json.dumps({"type": "token", "content": chunk.content}) + "\n"

    return StreamingResponse(generate_response(), media_type="application/x-ndjson")
@app.post("/ingest")
async def ingest_document(file: UploadFile = File(...)):
    """Store an uploaded document and run the ingestion pipeline on it.

    The file is saved under ``raw_documents/`` using only the base name of
    the uploaded filename, then ``ingest.py`` is executed and the vector
    store is reloaded. Responds with HTTP 400 if the filename is missing or
    invalid, and HTTP 500 (including the script's stderr) if ingestion fails.
    """
    # The filename is client-controlled: strip any directory components
    # (both separator styles) so the upload cannot escape raw_documents/.
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = os.path.basename(raw_name)
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid or missing filename.")

    os.makedirs("raw_documents", exist_ok=True)
    file_path = os.path.join("raw_documents", safe_name)
    with open(file_path, "wb") as f:
        # Copy in 1 MiB chunks instead of loading the whole upload in memory.
        while chunk := await file.read(1024 * 1024):
            f.write(chunk)

    # Run the ingestion script with the interpreter that is running this
    # server, in a worker thread so the event loop stays responsive.
    process = await asyncio.to_thread(
        subprocess.run,
        [sys.executable or "python", "ingest.py"],
        capture_output=True,
        text=True,
    )
    if process.returncode != 0:
        raise HTTPException(status_code=500, detail=f"Ingestion failed: {process.stderr}")

    # Reload the vector store only after a successful ingestion; the reload
    # is synchronous, so it is also kept off the event loop.
    await asyncio.to_thread(startup_event)
    return {"message": f"Successfully ingested {safe_name}", "logs": process.stdout}
@app.get("/sources")
async def get_sources():
    # Returns {"documents": [...]} with the entry names found in the
    # raw_documents upload directory, or an empty list when that directory
    # does not exist. (Comments are used instead of a docstring so the
    # route's OpenAPI description is unchanged.)
    upload_dir = "raw_documents"
    if not os.path.exists(upload_dir):
        return {"documents": []}
    return {"documents": os.listdir(upload_dir)}