# FetchMerck_AI / main.py
# Uploaded by jeremygracey-ai — initial upload of main.py (commit 4cccc3c, verified)
import logging

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from app_logic import load_llm_model, initialize_vector_db, get_rag_response
# Define the input data model
class QueryRequest(BaseModel):
    """Request body for POST /query."""

    # The natural-language question to be answered by the RAG pipeline.
    question: str
# Initialize FastAPI app
app = FastAPI(title="Medical RAG API")

# Configuration paths
# NOTE(review): MODEL_PATH pins one exact Hugging Face cache snapshot of the
# Mistral-7B-Instruct GGUF file — confirm this path exists on the deploy host.
MODEL_PATH = "/root/.cache/huggingface/hub/models--TheBloke--Mistral-7B-Instruct-v0.1-GGUF/snapshots/731a9fc8f06f5f5e2db8a0cf9d256197eb6e05d1/mistral-7b-instruct-v0.1.Q4_K_M.gguf"
# Directory holding the persisted Chroma vector store, relative to the CWD.
CHROMA_DIR = "./chroma_db"

# Global variables for the model and retriever.
# Both stay None until startup_event() populates them; /query returns 503
# while either is still None.
llm = None
retriever = None
@app.on_event("startup")
async def startup_event():
    """Load the LLM and the Chroma retriever once at process start.

    Populates the module-level ``llm`` and ``retriever`` globals.
    A failure is logged (with full traceback) but deliberately NOT
    re-raised, so the server still comes up; /query guards on the
    globals being None and answers 503 until both are loaded.
    """
    global llm, retriever
    log = logging.getLogger(__name__)
    try:
        log.info("Loading LLM and Vector Database...")
        llm = load_llm_model(MODEL_PATH)
        retriever = initialize_vector_db(CHROMA_DIR)
        log.info("Startup complete.")
    except Exception:
        # Previously a bare print(f"Error during startup: {e}") — that lost
        # the traceback. logger.exception records it while keeping the
        # best-effort semantics (process stays alive, endpoint returns 503).
        log.exception("Error during startup")
@app.post("/query")
async def query_rag(request: QueryRequest):
    """Answer a question with the RAG pipeline.

    Returns a JSON object echoing the question alongside the generated
    answer. Responds 503 while the model/retriever are still loading and
    500 if answer generation itself fails.
    """
    # Guard clause: both globals must be populated by the startup hook.
    ready = llm is not None and retriever is not None
    if not ready:
        raise HTTPException(status_code=503, detail="Model not initialized")
    try:
        answer = get_rag_response(request.question, llm, retriever)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"question": request.question, "answer": answer}
# Script entry point: run the API directly with `python main.py`,
# listening on all interfaces, port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)