Spaces:
Running
Running
File size: 4,890 Bytes
f83e60c 132a5cc f83e60c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 | import os
import subprocess
import json
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain_groq import ChatGroq
from sentence_transformers import CrossEncoder
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from fastapi import Request
# Populate os.environ from a .env file (if present) before GROQ_API_KEY is read
# in startup_event().
load_dotenv()

app = FastAPI(title="DocuMind Enterprise RAG API")

# Per-client rate limiting keyed by remote address. slowapi looks the limiter up
# on app.state; the middleware enforces limits and the handler renders
# RateLimitExceeded as an HTTP error response.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_middleware(SlowAPIMiddleware)
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Storage location and model identifiers.
CHROMA_DB_DIR = "vectorstore"  # Chroma persistence directory on disk
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # HuggingFaceEmbeddings model
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # CrossEncoder reranker model
LLM_MODEL = "llama-3.1-8b-instant"  # ChatGroq model name

# Shared handles: all start as None and are assigned by startup_event().
embeddings = vectorstore = base_retriever = cross_encoder = llm = None
@app.on_event("startup")
def startup_event():
    """Load (or reload) the embedding model, Chroma store, reranker and Groq LLM.

    Runs at application startup and is called again by the ``/ingest``
    endpoint after new documents are ingested. The embedding model,
    CrossEncoder and LLM client are only constructed if not already loaded;
    re-creating them on every ingest was slow and pointless. The Chroma
    handle and retriever are always reopened so they reflect the store on disk.

    Globals assigned:
        embeddings, cross_encoder: on first call.
        vectorstore, base_retriever: whenever ``CHROMA_DB_DIR`` exists
            (otherwise left unchanged).
        llm: on the first call where ``GROQ_API_KEY`` is set.
    """
    global embeddings, vectorstore, base_retriever, cross_encoder, llm
    print("Loading vector store & embedding model...")
    if embeddings is None:
        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
    if os.path.exists(CHROMA_DB_DIR):
        # Reopen on every call so a reload picks up newly ingested documents.
        vectorstore = Chroma(persist_directory=CHROMA_DB_DIR, embedding_function=embeddings)
        base_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
    if cross_encoder is None:
        print("Initializing CrossEncoder ReRanker...")
        cross_encoder = CrossEncoder(RERANKER_MODEL)
    if llm is None:
        print("Initializing LLM via Groq...")
        if not os.environ.get("GROQ_API_KEY"):
            print("WARNING: GROQ_API_KEY not found in environment!")
        else:
            llm = ChatGroq(model_name=LLM_MODEL, temperature=0, streaming=True)
class QueryRequest(BaseModel):
    """JSON request body accepted by ``POST /query``."""

    # The user's question; used for retrieval, reranking and the LLM prompt.
    question: str
# Grounded-answer prompt. {context} receives the reranked passages and
# {question} the user's query; both are supplied by query_documents().
prompt_template = PromptTemplate.from_template(
    "\n".join(
        [
            "You are a factual assistant for DocuMind. Answer ONLY using the context below.",
            "If the answer isn't in the context, say \"I don't know.\"",
            "Context: {context}",
            "Question: {question}",
        ]
    )
)
@app.post("/query")
@limiter.limit("5/minute")
async def query_documents(request: Request, req: QueryRequest):
    """Answer a question over the ingested documents as an NDJSON stream.

    Retrieves candidate chunks from the vector store, reranks them with the
    CrossEncoder, keeps the three best, and streams the LLM's answer.

    ``request`` is not used in the body; it is kept because slowapi's
    ``@limiter.limit`` expects a ``Request`` parameter on the endpoint.

    Stream format (one JSON object per line):
        ``{"type": "sources", "data": [...]}`` first, then
        ``{"type": "token", "content": "..."}`` per generated chunk.

    Raises:
        HTTPException: 500 if the retriever, reranker or LLM is not loaded.
    """
    # Local import keeps this block independent of the file's import section.
    from fastapi.concurrency import run_in_threadpool

    # Snapshot the shared handles once: /ingest calls startup_event(), which may
    # rebind these globals while this request is still running.
    retriever, reranker, model = base_retriever, cross_encoder, llm
    if retriever is None or reranker is None or model is None:
        raise HTTPException(status_code=500, detail="Backend not fully initialized (Vectorstore or LLM missing).")

    question = req.question
    # retriever.invoke and reranker.predict are synchronous; run them in the
    # threadpool so they do not stall the event loop for other requests.
    initial_docs = await run_in_threadpool(retriever.invoke, question)
    if not initial_docs:
        # Nothing retrieved: stream "I don't know." with empty sources.
        async def empty_response():
            yield json.dumps({"type": "sources", "data": []}) + "\n"
            yield json.dumps({"type": "token", "content": "I don't know."}) + "\n"

        return StreamingResponse(empty_response(), media_type="application/x-ndjson")

    pairs = [[question, doc.page_content] for doc in initial_docs]
    scores = await run_in_threadpool(reranker.predict, pairs)
    for doc, score in zip(initial_docs, scores):
        doc.metadata["relevance_score"] = float(score)
    top_docs = sorted(initial_docs, key=lambda d: d.metadata["relevance_score"], reverse=True)[:3]
    context_text = "\n\n".join(doc.page_content for doc in top_docs)
    chain = prompt_template | model

    async def generate_response():
        sources = [
            {
                "source": d.metadata.get("source", "Unknown"),
                "score": d.metadata.get("relevance_score"),
                "content": d.page_content,
            }
            for d in top_docs
        ]
        # Emit sources first so clients can render citations before the answer.
        yield json.dumps({"type": "sources", "data": sources}) + "\n"
        # Then emit each non-empty generated chunk as a token event.
        async for chunk in chain.astream({"context": context_text, "question": question}):
            if chunk.content:
                yield json.dumps({"type": "token", "content": chunk.content}) + "\n"

    return StreamingResponse(generate_response(), media_type="application/x-ndjson")
@app.post("/ingest")
async def ingest_document(file: UploadFile = File(...)):
    """Save an uploaded file to ``raw_documents/``, run ``ingest.py``, reload the store.

    Returns:
        ``{"message": ..., "logs": <ingest.py stdout>}`` on success.

    Raises:
        HTTPException: 400 if the upload has no usable filename;
            500 if ``ingest.py`` exits non-zero (detail includes its stderr).
    """
    # Local imports keep this block independent of the file's import section.
    import sys
    from fastapi.concurrency import run_in_threadpool

    # file.filename comes from the client. Keep only its final path component
    # (treating "\" as a separator too) so names like "../../app.py" or
    # "/etc/passwd" cannot write outside raw_documents/.
    filename = os.path.basename((file.filename or "").replace("\\", "/"))
    if filename in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid or missing filename.")

    os.makedirs("raw_documents", exist_ok=True)
    file_path = os.path.join("raw_documents", filename)
    contents = await file.read()
    with open(file_path, "wb") as f:
        f.write(contents)

    # Use the running interpreter (a bare "python" on PATH may be a different
    # environment) and run off the event loop, since ingestion blocks.
    process = await run_in_threadpool(
        subprocess.run, [sys.executable, "ingest.py"], capture_output=True, text=True
    )
    # Reload the vector store regardless of the outcome, as before; also blocking.
    await run_in_threadpool(startup_event)
    if process.returncode != 0:
        raise HTTPException(status_code=500, detail=f"Ingestion failed: {process.stderr}")
    return {"message": f"Successfully ingested {filename}", "logs": process.stdout}
@app.get("/sources")
async def get_sources():
    """List the entries in ``raw_documents/``; an empty list if it does not exist."""
    if not os.path.exists("raw_documents"):
        return {"documents": []}
    return {"documents": os.listdir("raw_documents")}
|