File size: 4,890 Bytes
f83e60c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132a5cc
f83e60c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import asyncio
import json
import os
import subprocess
import sys

from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi import Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from dotenv import load_dotenv

from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain_groq import ChatGroq
from sentence_transformers import CrossEncoder

from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware

# Load environment variables before anything reads them; startup_event()
# later looks up GROQ_API_KEY via os.environ.
load_dotenv()

app = FastAPI(title="DocuMind Enterprise RAG API")

# Rate limiting (slowapi): clients are keyed by remote address. The limiter is
# attached to app.state, a handler is registered for RateLimitExceeded, and the
# middleware is installed; per-route limits are declared with @limiter.limit.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware)

# Directory of the persisted Chroma store, checked for existence at startup.
CHROMA_DB_DIR = "vectorstore"
# Model identifiers passed to HuggingFaceEmbeddings, CrossEncoder and ChatGroq
# respectively in startup_event().
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
LLM_MODEL = "llama-3.1-8b-instant"

# Process-wide singletons populated by startup_event(). They stay None until
# then; vectorstore/base_retriever also stay None if CHROMA_DB_DIR is missing,
# and llm stays None if GROQ_API_KEY is not set.
embeddings = None
vectorstore = None
base_retriever = None
cross_encoder = None
llm = None

@app.on_event("startup")
def startup_event():
    """Initialise (or refresh) the module-level RAG components.

    Runs once on application startup and is also called again by the
    /ingest endpoint to pick up newly indexed documents. To keep those
    repeat calls cheap, the embedding model, the cross-encoder and the LLM
    client are created only if they do not exist yet; only the Chroma
    vector store and its retriever are rebuilt on every call.

    Side effects:
        Assigns the globals ``embeddings``, ``vectorstore``,
        ``base_retriever``, ``cross_encoder`` and ``llm``.
        ``vectorstore``/``base_retriever`` are left unchanged when
        CHROMA_DB_DIR does not exist, and ``llm`` is left unchanged when
        GROQ_API_KEY is not set.
    """
    global embeddings, vectorstore, base_retriever, cross_encoder, llm

    print("Loading vector store & embedding model...")
    if embeddings is None:
        # Model weights do not change between calls; load them only once.
        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
    if os.path.exists(CHROMA_DB_DIR):
        # Always reopen the store so documents added since the last call
        # become visible to the retriever.
        vectorstore = Chroma(persist_directory=CHROMA_DB_DIR, embedding_function=embeddings)
        base_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

    if cross_encoder is None:
        print("Initializing CrossEncoder ReRanker...")
        cross_encoder = CrossEncoder(RERANKER_MODEL)

    if llm is None:
        print("Initializing LLM via Groq...")
        if not os.environ.get("GROQ_API_KEY"):
            # Without a key the LLM stays None; /query reports this as a 500.
            print("WARNING: GROQ_API_KEY not found in environment!")
        else:
            llm = ChatGroq(model_name=LLM_MODEL, temperature=0, streaming=True)

class QueryRequest(BaseModel):
    """Request body for the /query endpoint."""

    # The user's question; used both as the retrieval query and in the prompt.
    question: str

# Prompt given to the LLM by /query. It has two placeholders: {context}
# (the concatenated text of the top re-ranked documents) and {question}
# (the user's question). The model is told to answer only from the context.
prompt_template = PromptTemplate.from_template("""You are a factual assistant for DocuMind. Answer ONLY using the context below.
If the answer isn't in the context, say "I don't know."
Context: {context}
Question: {question}""")

@app.post("/query")
@limiter.limit("5/minute")
async def query_documents(request: Request, req: QueryRequest):
    """Answer a question from the indexed documents as an NDJSON stream.

    Pipeline: retrieve candidates from the vector store, re-rank them with
    the cross-encoder, keep the three best, and stream the LLM's answer.

    Args:
        request: The incoming HTTP request. Not used directly in the body;
            it is part of the signature for the rate-limit decorator.
        req: Parsed request body carrying the question.

    Returns:
        A StreamingResponse (``application/x-ndjson``). The first line is
        ``{"type": "sources", "data": [...]}``; every following line is
        ``{"type": "token", "content": "..."}``. When retrieval finds
        nothing, the sources list is empty and the single token is
        "I don't know.".

    Raises:
        HTTPException: 500 when the retriever or the LLM was not
            initialised at startup.
    """
    if not base_retriever or not llm:
        raise HTTPException(status_code=500, detail="Backend not fully initialized (Vectorstore or LLM missing).")

    # Retrieval is a synchronous call; run it in a worker thread so the
    # event loop keeps serving other requests meanwhile.
    initial_docs = await asyncio.to_thread(base_retriever.invoke, req.question)

    if not initial_docs:
        # Stream "I don't know." with empty sources
        async def empty_response():
            yield json.dumps({"type": "sources", "data": []}) + "\n"
            yield json.dumps({"type": "token", "content": "I don't know."}) + "\n"
        return StreamingResponse(empty_response(), media_type="application/x-ndjson")

    # Score every (question, passage) pair; model inference is blocking, so
    # it is moved off the event loop as well.
    pairs = [[req.question, doc.page_content] for doc in initial_docs]
    scores = await asyncio.to_thread(cross_encoder.predict, pairs)
    for doc, score in zip(initial_docs, scores):
        # float() converts the model's score to a JSON-serialisable number.
        doc.metadata['relevance_score'] = float(score)

    # Highest relevance first; only the best three feed the prompt.
    initial_docs.sort(key=lambda d: d.metadata['relevance_score'], reverse=True)
    top_docs = initial_docs[:3]

    context_text = "\n\n".join([doc.page_content for doc in top_docs])
    chain = prompt_template | llm

    async def generate_response():
        sources = [
            {
                "source": d.metadata.get("source", "Unknown"),
                "score": d.metadata.get("relevance_score"),
                "content": d.page_content,
            }
            for d in top_docs
        ]
        # Emit sources first
        yield json.dumps({"type": "sources", "data": sources}) + "\n"

        # Emit tokens
        async for chunk in chain.astream({"context": context_text, "question": req.question}):
            if chunk.content:
                yield json.dumps({"type": "token", "content": chunk.content}) + "\n"

    return StreamingResponse(generate_response(), media_type="application/x-ndjson")

@app.post("/ingest")
async def ingest_document(file: UploadFile = File(...)):
    """Store an uploaded file, run the ingestion script, and reload the index.

    Args:
        file: The uploaded document. Only the final path component of its
            client-supplied filename is used.

    Returns:
        A dict with a success ``message`` and the ingestion script's
        stdout under ``logs``.

    Raises:
        HTTPException: 400 when the upload has no usable filename;
            500 when ``ingest.py`` exits with a non-zero status (its stderr
            is included in the detail).
    """
    # The filename comes from the client. Drop any directory part (either
    # separator style) so it cannot point outside raw_documents/.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Uploaded file has no usable filename.")

    os.makedirs("raw_documents", exist_ok=True)
    file_path = os.path.join("raw_documents", safe_name)
    with open(file_path, "wb") as f:
        f.write(await file.read())

    # Run the ingestion script with the same interpreter as this server
    # (a bare "python" may be missing or belong to another environment),
    # and in a worker thread so the event loop is not blocked while it runs.
    process = await asyncio.to_thread(
        subprocess.run,
        [sys.executable, "ingest.py"],
        capture_output=True,
        text=True,
    )

    # Reload vectorstore inline (also off the event loop; it is blocking).
    await asyncio.to_thread(startup_event)

    if process.returncode != 0:
        raise HTTPException(status_code=500, detail=f"Ingestion failed: {process.stderr}")

    return {"message": f"Successfully ingested {safe_name}", "logs": process.stdout}

@app.get("/sources")
async def get_sources():
    """Return the names of the entries in the raw_documents directory.

    Returns:
        ``{"documents": [...]}``; the list is empty when the directory
        does not exist.
    """
    upload_dir = "raw_documents"
    # The directory is only created on first upload, so it may be absent.
    names = os.listdir(upload_dir) if os.path.exists(upload_dir) else []
    return {"documents": names}