Spaces:
Running
Running
import json
import os
import subprocess
import sys

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi import Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from langchain_chroma import Chroma
from langchain_core.prompts import PromptTemplate
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from pydantic import BaseModel
from sentence_transformers import CrossEncoder
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
# Load variables from a local .env file into the process environment before
# anything below reads os.environ.
load_dotenv()

# Model identifiers and on-disk location of the persisted vector store.
CHROMA_DB_DIR = "vectorstore"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
LLM_MODEL = "llama-3.1-8b-instant"

app = FastAPI(title="DocuMind Enterprise RAG API")

# Rate limiting: clients are keyed by remote address, and the limiter is
# wired into the app through its state, exception handler and middleware.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware)

# Shared components; they stay None until startup_event() populates them.
embeddings = vectorstore = base_retriever = cross_encoder = llm = None
def startup_event():
    """Load, or refresh, the shared retrieval and generation components.

    Populates the module-level ``embeddings``, ``vectorstore``,
    ``base_retriever``, ``cross_encoder`` and ``llm`` globals.

    The function is safe to call repeatedly: the embedding model, the
    cross-encoder and the LLM client are created only once, while the vector
    store and its retriever are rebuilt on every call so newly ingested
    documents become visible.

    NOTE(review): no startup-hook registration is visible in this view;
    confirm this function is actually registered to run at application start.
    """
    global embeddings, vectorstore, base_retriever, cross_encoder, llm

    print("Loading vector store & embedding model...")
    # Loading model weights is expensive; reuse the instance on later calls.
    if embeddings is None:
        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

    if os.path.exists(CHROMA_DB_DIR):
        # Build the store first, then publish store and retriever together.
        store = Chroma(persist_directory=CHROMA_DB_DIR, embedding_function=embeddings)
        vectorstore = store
        base_retriever = store.as_retriever(search_kwargs={"k": 5})
    else:
        print(f"WARNING: vector store directory '{CHROMA_DB_DIR}' not found; retrieval is unavailable.")

    if cross_encoder is None:
        print("Initializing CrossEncoder ReRanker...")
        cross_encoder = CrossEncoder(RERANKER_MODEL)

    if llm is None:
        print("Initializing LLM via Groq...")
        if not os.environ.get("GROQ_API_KEY"):
            print("WARNING: GROQ_API_KEY not found in environment!")
        else:
            llm = ChatGroq(model_name=LLM_MODEL, temperature=0, streaming=True)
class QueryRequest(BaseModel):
    """Request body for a document query."""

    # The natural-language question to answer from the indexed documents.
    question: str
# Prompt sent to the LLM: it restricts the answer to the supplied context and
# takes two variables, {context} and {question}.
prompt_template = PromptTemplate.from_template(
    "You are a factual assistant for DocuMind. Answer ONLY using the context below.\n"
    "If the answer isn't in the context, say \"I don't know.\"\n"
    "Context: {context}\n"
    "Question: {question}"
)
async def query_documents(request: Request, req: QueryRequest):
    """Answer a question from the indexed documents as an NDJSON stream.

    Pipeline: fetch candidate passages with ``base_retriever``, score each
    (question, passage) pair with ``cross_encoder``, keep the three
    best-scoring passages and stream the LLM answer generated from them.

    The response body is newline-delimited JSON: first a single
    ``{"type": "sources", "data": [...]}`` record, then one
    ``{"type": "token", "content": "..."}`` record per non-empty LLM chunk.
    When retrieval returns nothing, the stream is an empty source list
    followed by the single token "I don't know.".

    Args:
        request: The incoming HTTP request. Not read by this function.
            NOTE(review): presumably required by a rate-limit decorator that
            is not visible in this view - confirm.
        req: Payload carrying the user's question.

    Returns:
        A ``StreamingResponse`` with media type ``application/x-ndjson``.

    Raises:
        HTTPException: 500 when the retriever or the LLM is not initialized.
    """
    if base_retriever is None or llm is None:
        raise HTTPException(status_code=500, detail="Backend not fully initialized (Vectorstore or LLM missing).")

    question = req.question

    # Retrieval is a synchronous call; run it in a worker thread so this
    # coroutine does not stall the event loop for every other request.
    initial_docs = await run_in_threadpool(base_retriever.invoke, question)

    if not initial_docs:
        # Stream "I don't know." with empty sources
        async def empty_response():
            yield json.dumps({"type": "sources", "data": []}) + "\n"
            yield json.dumps({"type": "token", "content": "I don't know."}) + "\n"

        return StreamingResponse(empty_response(), media_type="application/x-ndjson")

    # Re-rank the candidates with the cross-encoder, also off the event loop.
    pairs = [[question, doc.page_content] for doc in initial_docs]
    scores = await run_in_threadpool(cross_encoder.predict, pairs)
    for doc, score in zip(initial_docs, scores):
        doc.metadata["relevance_score"] = float(score)

    ranked_docs = sorted(
        initial_docs,
        key=lambda d: d.metadata["relevance_score"],
        reverse=True,
    )
    top_docs = ranked_docs[:3]
    context_text = "\n\n".join(doc.page_content for doc in top_docs)
    chain = prompt_template | llm

    async def generate_response():
        sources = [
            {
                "source": d.metadata.get("source", "Unknown"),
                "score": d.metadata.get("relevance_score"),
                "content": d.page_content,
            }
            for d in top_docs
        ]
        # Emit sources first
        yield json.dumps({"type": "sources", "data": sources}) + "\n"
        # Emit tokens; chunks with empty content are skipped.
        async for chunk in chain.astream({"context": context_text, "question": question}):
            if chunk.content:
                yield json.dumps({"type": "token", "content": chunk.content}) + "\n"

    return StreamingResponse(generate_response(), media_type="application/x-ndjson")
async def ingest_document(file: UploadFile = File(...)):
    """Store an uploaded file under ``raw_documents`` and re-run ingestion.

    The upload is written to disk, ``ingest.py`` is executed as a child
    process, and on success the in-memory vector store is refreshed through
    ``startup_event()``.

    Args:
        file: The uploaded document.

    Returns:
        A dict with a success ``message`` and the ingestion script's stdout
        under ``logs``.

    Raises:
        HTTPException: 400 when the upload carries no usable filename;
            500 when the ingestion script exits with a non-zero status.
    """
    # The filename is client-controlled. Keep only its last path component so
    # values such as "../../main.py" or an absolute path cannot write outside
    # raw_documents. Backslashes are normalized so Windows-style paths are
    # stripped as well.
    raw_name = file.filename or ""
    safe_name = os.path.basename(raw_name.replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Uploaded file has no usable filename.")

    os.makedirs("raw_documents", exist_ok=True)
    file_path = os.path.join("raw_documents", safe_name)
    payload = await file.read()
    with open(file_path, "wb") as f:
        f.write(payload)

    # Run the ingestion script with the interpreter that runs this server
    # (so it sees the same installed packages), in a worker thread so the
    # event loop keeps serving other requests while it runs.
    interpreter = sys.executable or "python"
    process = await run_in_threadpool(
        subprocess.run,
        [interpreter, "ingest.py"],
        capture_output=True,
        text=True,
    )
    if process.returncode != 0:
        # Fail before reloading: there is nothing new worth loading.
        raise HTTPException(status_code=500, detail=f"Ingestion failed: {process.stderr}")

    # Reload vectorstore inline; this is blocking work, so keep it off the loop.
    await run_in_threadpool(startup_event)
    return {"message": f"Successfully ingested {safe_name}", "logs": process.stdout}
async def get_sources():
    """List the entries of the ``raw_documents`` directory.

    Returns:
        ``{"documents": [...]}`` holding the names reported by
        ``os.listdir``, or an empty list when the directory does not exist.
    """
    upload_dir = "raw_documents"
    names = os.listdir(upload_dir) if os.path.exists(upload_dir) else []
    return {"documents": names}