File size: 1,453 Bytes
4cccc3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
import traceback

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from app_logic import load_llm_model, initialize_vector_db, get_rag_response

# Define the input data model
class QueryRequest(BaseModel):
    """Request body accepted by the ``/query`` endpoint."""

    # Question text that the endpoint forwards to the RAG pipeline.
    question: str

# Initialize FastAPI app
# This instance is what the startup hook and the /query route below register
# against, and what uvicorn serves when the file is run as a script.
app = FastAPI(title="Medical RAG API")

# Configuration paths.
# Each value can be overridden with an environment variable of the same name,
# so the service is not tied to one machine's cache layout. An unset or empty
# variable falls back to the original hard-coded default.
MODEL_PATH = os.environ.get("MODEL_PATH") or (
    "/root/.cache/huggingface/hub/models--TheBloke--Mistral-7B-Instruct-v0.1-GGUF/snapshots/731a9fc8f06f5f5e2db8a0cf9d256197eb6e05d1/mistral-7b-instruct-v0.1.Q4_K_M.gguf"
)
CHROMA_DIR = os.environ.get("CHROMA_DIR") or "./chroma_db"

# Module-level handles for the LLM and the retriever. Both start out as None
# and are filled in by the startup hook; /query checks them before use.
llm = retriever = None

@app.on_event("startup")
async def startup_event():
    """Load the LLM and the vector-database retriever when the app starts.

    The results are stored in the module-level ``llm`` and ``retriever``
    globals. Failures are reported but deliberately not re-raised, so the
    server still comes up; ``/query`` answers 503 while either global is
    ``None``.
    """
    # NOTE(review): recent FastAPI releases deprecate ``on_event`` in favour of
    # lifespan handlers; migrating would also touch the ``FastAPI(...)`` call.
    global llm, retriever
    try:
        print("Loading LLM and Vector Database...")
        llm = load_llm_model(MODEL_PATH)
        retriever = initialize_vector_db(CHROMA_DIR)
        print("Startup complete.")
    except Exception as e:
        print(f"Error during startup: {e}")
        # str(e) alone drops the exception type and the failing location (and
        # can be empty), so also emit the full traceback for diagnosis.
        traceback.print_exc()

@app.post("/query")
async def query_rag(request: QueryRequest):
    """Answer a question through the RAG pipeline.

    Args:
        request: Request body carrying the ``question`` to answer.

    Returns:
        A dict with the submitted ``question`` and, under ``answer``, the
        value returned by ``get_rag_response``.

    Raises:
        HTTPException: 503 when the LLM or retriever has not been loaded;
            500 (with the error text as ``detail``) when the pipeline raises.
            An ``HTTPException`` raised by the pipeline itself is passed
            through unchanged.
    """
    if llm is None or retriever is None:
        raise HTTPException(status_code=503, detail="Model not initialized")

    try:
        # NOTE(review): this is a synchronous call inside an async handler, so
        # it presumably blocks the event loop while it runs. Confirm the model
        # tolerates concurrent calls before moving it to a worker thread.
        response = get_rag_response(request.question, llm, retriever)
    except HTTPException:
        # Keep a deliberate HTTP status from the pipeline instead of
        # rewrapping it as a generic 500.
        raise
    except Exception as e:
        # Leave a server-side trace; the client only receives str(e).
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"question": request.question, "answer": response}

# When this file is executed directly (rather than imported), serve the app
# with uvicorn, listening on all interfaces (0.0.0.0) at port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)