| import os |
| import sqlite3 |
| from dotenv import load_dotenv |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, Field |
| from typing import List |
| from contextlib import asynccontextmanager |
|
|
|
|
| from langchain_core.documents import Document |
| from langchain_together import ChatTogether, TogetherEmbeddings |
| from langchain_core.prompts import ChatPromptTemplate |
| from langchain_core.output_parsers import StrOutputParser |
| from langchain_core.runnables import RunnablePassthrough |
| from langchain_core.documents import Document |
|
|
| import chromadb |
| from langchain_chroma import Chroma |
|
|
| |
| load_dotenv() |
|
|
# --- Configuration ---------------------------------------------------------

TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
if not TOGETHER_API_KEY:
    raise ValueError("TOGETHER_API_KEY environment variable not set. Please check your .env file.")

# Location of the persisted Chroma vector store; overridable via env var.
VECTOR_DB_DIR = os.getenv("VECTOR_DB_DIR", "/tmp/vector_db_chroma")
COLLECTION_NAME = "my_instrument_manual_chunks"

LLM_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
EMBEDDINGS_MODEL_NAME = "intfloat/multilingual-e5-large-instruct"

# RAG components are populated by the lifespan handler before the app serves
# traffic. Initialize them here so the guard in /ask (`rag_chain is None`)
# fails with a clean HTTP 500 instead of raising NameError if startup
# initialization never ran.
rag_chain = None
retriever = None
prompt = None
llm = None
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler: build the RAG pipeline once at startup.

    Initializes the Together LLM, embeddings, Chroma-backed retriever,
    prompt template, and the composed RAG chain, publishing them through
    the module-level globals used by the /ask endpoint.

    Raises:
        RuntimeError: if any component fails to initialize (original cause
            is chained for debuggability).
    """
    global rag_chain, retriever, prompt, llm

    print("--- Initializing RAG components ---")
    try:
        llm = ChatTogether(
            model=LLM_MODEL_NAME,
            temperature=0.3,
            api_key=TOGETHER_API_KEY,
        )
        print(f"LLM {LLM_MODEL_NAME} initialized.")

        embeddings = TogetherEmbeddings(
            model=EMBEDDINGS_MODEL_NAME,
            api_key=TOGETHER_API_KEY,
        )
        # NOTE: Chroma creates its own persistent client from
        # persist_directory; the separate chromadb.PersistentClient the
        # original built here was never used and has been removed.
        vectorstore = Chroma(
            persist_directory=VECTOR_DB_DIR,
            collection_name=COLLECTION_NAME,
            embedding_function=embeddings,
        )
        retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
        print("Retriever initialized.")

        answer_prompt = """ You are a professional HPLC instrument troubleshooting expert who specializes in helping junior researchers and students.
Your task is to answer the user's troubleshooting questions in detail and clearly based on the HPLC instrument knowledge provided below.
If there is no direct answer in the knowledge, please provide the most reasonable speculative suggestions based on your expert judgment, or ask further clarifying questions.
Please ensure that your answers are logically clear, easy to understand, and directly address the user's questions."""
        prompt = ChatPromptTemplate.from_messages([
            ("system", answer_prompt),
            ("user", "context: {context}\n\nquestion: {question}"),
        ])

        def format_docs(docs: List[Document]) -> str:
            # Join retrieved chunks into one context string for the prompt.
            return "\n\n".join(doc.page_content for doc in docs)

        rag_chain = (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )
        print("RAG chain ready.")

    except Exception as e:
        # Chain the original exception so the startup failure keeps its
        # full traceback.
        raise RuntimeError(f"Failed to initialize RAG chain: {e}") from e

    yield
|
| |
# FastAPI application; RAG components are wired up by the lifespan handler.
app = FastAPI(
    title="LabAid AI",
    description="API service for a Retrieval-Augmented Generation (RAG) AI assistant.",
    version="1.0.0",
    lifespan=lifespan
)

# Browser origins allowed to call this API (local development frontends).
origins = [
    "http://localhost",
    "http://localhost:3000",
    "http://127.0.0.1:8000",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
class QueryRequest(BaseModel):
    """Request body for POST /ask."""

    # Free-text troubleshooting question from the user.
    query: str
|
|
class QueryResponse(BaseModel):
    """Response body for POST /ask."""

    # LLM-generated answer text.
    answer: str
    # Page contents of the retrieved chunks used as context.
    source_documents: List[str]
|
|
| |
| @app.post("/ask", response_model=QueryResponse) |
| async def ask_rag(request: QueryRequest): |
| if rag_chain is None or retriever is None: |
| raise HTTPException(status_code=500, detail="RAG chain not initialized.") |
|
|
| try: |
| user_query = request.query |
| print(f"Received query: {user_query}") |
|
|
| retrieved_docs = retriever.invoke(user_query) |
| formatted_context = "\n\n".join(doc.page_content for doc in retrieved_docs) |
|
|
| answer = (prompt | llm | StrOutputParser()).invoke({ |
| "context": formatted_context, |
| "question": user_query |
| }) |
|
|
| sources = [doc.page_content for doc in retrieved_docs] |
| return QueryResponse(answer=answer, source_documents=sources) |
|
|
| except Exception as e: |
| print(f"Error: {e}") |
| raise HTTPException(status_code=500, detail=f"Failed to process query: {e}") |
| |
| |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.responses import FileResponse |
|
|
| |
| app.mount("/static", StaticFiles(directory="build/static"), name="static") |
|
|
| |
| @app.get("/") |
| async def serve_react_app(): |
| return FileResponse("build/index.html") |
|
|
|
|