| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel |
| from contextlib import asynccontextmanager |
| from langchain_community.document_loaders import PyPDFLoader |
| from langchain_community.document_loaders import WebBaseLoader |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| from langchain_community.vectorstores import FAISS |
| from langchain_openai import OpenAIEmbeddings |
| from langchain_community.embeddings import HuggingFaceEmbeddings |
| from langchain_openai import ChatOpenAI |
| from langchain_groq import ChatGroq |
| from langchain.chains import create_history_aware_retriever, create_retrieval_chain |
| from langchain.chains.combine_documents import create_stuff_documents_chain |
| from langchain_community.chat_message_histories import ChatMessageHistory |
| from langchain_core.chat_history import BaseChatMessageHistory |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
| from langchain_core.runnables.history import RunnableWithMessageHistory |
| from transformers import pipeline |
| from bs4 import BeautifulSoup |
| from dotenv import load_dotenv |
| from PIL import Image |
| import base64 |
| import requests |
| import docx2txt |
| import pptx |
| import os |
| import utils |
|
|
| |
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Prepare process-wide state on startup and clean up on shutdown.

    Startup:
      * ``load_dotenv()`` loads entries from a ``.env`` file into ``os.environ``.
      * ``LANGCHAIN_TRACING_V2`` is set to ``"true"`` (LangChain tracing flag).
      * The Hugging Face ``image-to-text`` pipeline
        (``Salesforce/blip-image-captioning-large``) is built once and stored in
        the module-global ``image_to_text``, which the GROQ branch of
        ``/image/{llm}`` uses.

    Shutdown:
      ``utils.unlink_images("/images")`` runs even if serving ends with an
      exception (presumably it deletes stored images -- confirm in ``utils``).

    Args:
        app: The FastAPI application (required by the lifespan protocol; unused).
    """
    global image_to_text

    load_dotenv()
    os.environ["LANGCHAIN_TRACING_V2"] = "true"

    # LANGCHAIN_API_KEY, GROQ_API_KEY, HF_TOKEN and NGROK_AUTHTOKEN stay in
    # os.environ for the libraries and endpoints that read them. Re-assigning
    # ``os.environ[k] = os.getenv(k)`` is a no-op when a key is set, and a
    # TypeError (os.environ only accepts str values) that aborts startup when
    # it is not. So missing keys are tolerated here and surface where they
    # are actually used.

    image_to_text = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
    try:
        yield
    finally:
        utils.unlink_images("/images")
|
|
| |
| |
# The application object; interactive API docs are served at the site root ("/").
# NOTE(review): the `welcome` handler is also registered on GET "/". FastAPI adds
# the docs route while the app is being constructed, so that route is likely
# matched first and the welcome handler never runs -- decide which one should own "/".
app = FastAPI(docs_url="/", lifespan=lifespan)
|
|
| |
| |
class APIKey(BaseModel):
    """Request body for ``/set_api_key``: an OpenAI API key supplied by the client."""

    api_key: str  # copied verbatim into the OPENAI_API_KEY environment variable
|
|
| |
class FileInfo(BaseModel):
    """Request body for ``/load_file/{llm}``: a file on the server and its MIME type."""

    file_path: str  # path readable by the server; load_file deletes it after reading
    file_type: str  # MIME type that selects the loader (PDF, DOCX, TXT, PPTX, HTML)
|
|
| |
class Image(BaseModel):
    """Request body for ``/image/{llm}``: location of an image file to describe.

    NOTE(review): this name shadows ``PIL.Image`` imported at the top of the
    file; confirm nothing else in the module expects the PIL class.
    """

    image_path: str  # local path of the image file to read
|
|
| |
class Website(BaseModel):
    """Request body for ``/load_link/{llm}``: the page to fetch and index."""

    website_link: str  # URL handed to WebBaseLoader
|
|
| |
class Question(BaseModel):
    """Request body for ``/answer_with_chat_history/{llm}``."""

    question: str  # the user's question text
    resource: str  # which index to query: "file" or "web"
|
|
| |
| |
def format_docs(docs):
    """Concatenate each document's ``page_content``, separated by a blank line.

    Args:
        docs: Iterable of objects that have a ``page_content`` string attribute.

    Returns:
        A single string; empty when ``docs`` is empty.
    """
    separator = "\n\n"
    contents = [document.page_content for document in docs]
    return separator.join(contents)
|
|
| |
def encode_image(image_path):
    """Return the bytes of the file at ``image_path`` as a Base64 text string.

    The file is read in binary mode. The Base64 output is decoded to ``str`` so
    it can be embedded in a ``data:`` URL.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
|
|
| |
| |
@app.get("/")
async def welcome():
    """Return a plain greeting string for GET "/".

    NOTE(review): the app is constructed with ``docs_url="/"``, so the Swagger UI
    is registered on the same path, and it is added before this handler. It is
    likely matched first, which would leave this handler unreachable -- confirm.
    """
    greeting = "Welcome to Brainbot!"
    return greeting
|
|
| |
@app.post("/set_api_key")
async def set_api_key(api_key: APIKey):
    """Store the client-supplied OpenAI key in the process environment.

    OpenAI-backed code paths (``OpenAIEmbeddings``, ``ChatOpenAI`` and the raw
    ``/image`` request) pick the key up from ``OPENAI_API_KEY``.

    NOTE(review): ``os.environ`` is process-wide, so the key applies to every
    later request from every client, and no authentication is visible on this
    route -- confirm that is intended.
    """
    key_value = api_key.api_key
    os.environ["OPENAI_API_KEY"] = key_value
    return "API key set successfully!"
|
|
| |
| |
def _load_file_documents(file_path, file_type, text_splitter):
    """Load ``file_path`` into LangChain documents according to its MIME type.

    PDFs are loaded with ``PyPDFLoader``. DOCX, plain text, PPTX and HTML are
    reduced to text and chunked with ``text_splitter``.

    Raises:
        ValueError: if ``file_type`` is not a supported MIME type.
    """
    if file_type == "application/pdf":
        return PyPDFLoader(file_path).load()
    if file_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
        return text_splitter.create_documents([docx2txt.process(file_path)])
    if file_type == "text/plain":
        with open(file_path, "r", encoding="utf-8") as file:
            return text_splitter.create_documents([file.read()])
    if file_type == "application/vnd.openxmlformats-officedocument.presentationml.presentation":
        presentation = pptx.Presentation(file_path)
        slide_texts = []
        for slide in presentation.slides:
            # Only shapes that expose a `text` attribute contribute to the slide's text.
            slide_text = "".join(shape.text + "\n" for shape in slide.shapes if hasattr(shape, "text"))
            slide_texts.append(slide_text.strip())
        return text_splitter.create_documents(slide_texts)
    if file_type == "text/html":
        with open(file_path, "r", encoding="utf-8") as file:
            soup = BeautifulSoup(file, "html.parser")
        return text_splitter.create_documents([soup.get_text()])
    raise ValueError(f"Unsupported file type: {file_type}")


@app.post("/load_file/{llm}")
async def load_file(llm: str, file_info: FileInfo):
    """Parse a file on the server, chunk it, and index it in ``file_vectorstore``.

    Processing steps:
      1. The file at ``file_info.file_path`` is loaded according to
         ``file_info.file_type``.
      2. It is split into chunks of up to 1000 characters with 200 characters
         of overlap.
      3. The chunks are embedded with ``OpenAIEmbeddings`` (``llm == "GPT-4"``)
         or ``HuggingFaceEmbeddings`` (``llm == "GROQ"``).
      4. The result is stored as a FAISS index in the module-global
         ``file_vectorstore``, replacing any previous index.

    The file is deleted after it has been read, whether or not parsing succeeded.

    Raises:
        HTTPException: status 500 carrying the underlying error message on any
            failure, including an unsupported ``file_type`` or ``llm``.
    """
    global file_vectorstore

    file_path = file_info.file_path
    file_type = file_info.file_type

    try:
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

        try:
            docs = _load_file_documents(file_path, file_type, text_splitter)
        finally:
            # The upload is treated as a temporary file: remove it even when parsing fails.
            # NOTE(review): file_path comes straight from the request body, so a client can
            # make the server delete any file it has write access to -- confirm callers
            # restrict this to an upload directory.
            try:
                os.unlink(file_path)
            except FileNotFoundError:
                pass  # already gone; any load error is the more useful one to report

        documents = text_splitter.split_documents(docs)

        if llm == "GPT-4":
            embeddings = OpenAIEmbeddings()
        elif llm == "GROQ":
            embeddings = HuggingFaceEmbeddings()
        else:
            raise ValueError(f"Unsupported llm: {llm}")

        file_vectorstore = FAISS.from_documents(documents, embeddings)
    except Exception as e:
        # str(e) carries the actual message (str(e.with_traceback) would print a bound method).
        raise HTTPException(status_code=500, detail=str(e)) from e
    return "File uploaded successfully!"
|
|
| |
| |
# Upper bound, in seconds, on the OpenAI HTTP call so a stalled request cannot hang forever.
_OPENAI_TIMEOUT_SECONDS = 120


def _describe_image_with_gpt4(image_path):
    """Ask OpenAI's ``gpt-4-turbo`` chat-completions API what is in the image.

    Raises:
        KeyError: if ``OPENAI_API_KEY`` is not set.
        requests.HTTPError: if the API answers with a non-2xx status.
    """
    base64_image = encode_image(image_path)
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",
    }
    payload = {
        "model": "gpt-4-turbo",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What's in this image?"},
                    {
                        "type": "image_url",
                        # NOTE(review): the MIME type is always declared as JPEG, whatever the file is.
                        "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                    },
                ],
            }
        ],
        "max_tokens": 300,
    }
    response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=headers,
        json=payload,
        timeout=_OPENAI_TIMEOUT_SECONDS,
    )
    # Report API failures (bad key, quota, ...) directly instead of failing on a missing "choices" key.
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]


def _describe_image_with_groq(image_path):
    """Caption the image with the local BLIP pipeline, then have Llama 3 on Groq explain it.

    Returns the capitalized caption, followed by ". " and the model's paragraph.
    """
    caption = image_to_text(image_path)[0]["generated_text"]
    chat = ChatGroq(temperature=0, groq_api_key=os.environ["GROQ_API_KEY"], model_name="Llama3-8b-8192")
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are an assistant to understand and interpret images."),
            ("human", "{text}"),
        ]
    )
    chain = prompt | chat
    text = f"Explain the following image description in a small paragraph. {caption}"
    response = chain.invoke({"text": text})
    return caption.capitalize() + ". " + response.content


@app.post("/image/{llm}")
async def interpret_image(llm: str, image: Image):
    """Describe the image at ``image.image_path`` using the backend chosen by ``llm``.

    Backends:
      * ``"GPT-4"`` sends the Base64-encoded image to OpenAI.
      * ``"GROQ"`` captions the image locally, then expands the caption with Llama 3.

    NOTE(review): the backends make blocking network/model calls inside an
    ``async`` handler, which stalls the event loop while they run.

    Raises:
        HTTPException: status 500 carrying the underlying error message on any
            failure, including an unsupported ``llm``.
    """
    try:
        if llm == "GPT-4":
            description = _describe_image_with_gpt4(image.image_path)
        elif llm == "GROQ":
            description = _describe_image_with_groq(image.image_path)
        else:
            raise ValueError(f"Unsupported llm: {llm}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return description
|
|
| |
| |
| |
@app.post("/load_link/{llm}")
async def website_info(llm: str, link: Website):
    """Fetch a web page, chunk it, and index it in ``website_vectorstore``.

    Processing steps:
      1. The raw documents returned by ``WebBaseLoader`` are kept in the
         module-global ``web_documents``.
      2. They are split into chunks of up to 1000 characters with 200
         characters of overlap.
      3. The chunks are embedded with ``OpenAIEmbeddings`` (``llm == "GPT-4"``)
         or ``HuggingFaceEmbeddings`` (``llm == "GROQ"``).
      4. The result is stored as a FAISS index in the module-global
         ``website_vectorstore``, replacing any previous index.

    NOTE(review): the URL comes straight from the client and is fetched by the
    server (SSRF risk) -- confirm this is acceptable for the deployment.

    Raises:
        HTTPException: status 500 carrying the underlying error message on any
            failure, including an unsupported ``llm``.
    """
    global web_documents, website_vectorstore

    try:
        loader = WebBaseLoader(web_paths=(link.website_link,))
        web_documents = loader.load()

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        documents = text_splitter.split_documents(web_documents)

        if llm == "GPT-4":
            embeddings = OpenAIEmbeddings()
        elif llm == "GROQ":
            embeddings = HuggingFaceEmbeddings()
        else:
            raise ValueError(f"Unsupported llm: {llm}")

        website_vectorstore = FAISS.from_documents(documents, embeddings)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return "Website loaded successfully!"
|
|
| |
| |
# Chat histories keyed by session id. This store lives at module level so history
# persists across requests. A dict created inside the handler starts empty on every
# call, which leaves the history-aware retriever nothing to contextualize against.
# NOTE(review): grows without bound (one entry per session id, messages appended
# forever) -- consider trimming or expiring sessions.
_chat_history_store = {}

_CONTEXTUALIZE_Q_SYSTEM_PROMPT = (
    "Given a chat history and the latest user question "
    "which might reference context in the chat history, formulate a standalone question "
    "which can be understood without the chat history. Do NOT answer the question, "
    "just reformulate it if needed and otherwise return it as is."
)

# The blank line keeps the retrieved context from running into the last instruction.
_QA_SYSTEM_PROMPT = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer the question. "
    "If you don't know the answer, just say that you don't know. "
    "Use three sentences maximum and keep the answer concise."
    "\n\n"
    "{context}"
)


def _get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the message history for ``session_id``, creating an empty one on first use."""
    if session_id not in _chat_history_store:
        _chat_history_store[session_id] = ChatMessageHistory()
    return _chat_history_store[session_id]


@app.post("/answer_with_chat_history/{llm}")
async def get_answer_with_chat_history(llm: str, question: Question, session_id: str = "abc123"):
    """Answer ``question.question`` with a conversational RAG chain over a loaded index.

    Inputs:
      * ``llm`` selects the chat model: ``"GPT-4"`` (``gpt-4-turbo``) or
        ``"GROQ"`` (``Llama3-8b-8192``).
      * ``question.resource`` selects the index: ``"file"`` uses
        ``file_vectorstore`` and ``"web"`` uses ``website_vectorstore``
        (top-5 similarity search).

    Each question is first rewritten into a standalone question using that
    session's chat history. The turn is then recorded under ``session_id``, an
    optional query parameter that defaults to the previously hard-coded
    ``"abc123"``, so clients that omit it share one conversation.

    Raises:
        HTTPException: status 500 carrying the underlying error message on any
            failure, including an unsupported ``llm``/``resource`` or an index
            that has not been loaded yet.
    """
    try:
        if llm == "GPT-4":
            chat_model = ChatOpenAI(model="gpt-4-turbo", temperature=0)
        elif llm == "GROQ":
            chat_model = ChatGroq(groq_api_key=os.environ["GROQ_API_KEY"], model_name="Llama3-8b-8192")
        else:
            raise ValueError(f"Unsupported llm: {llm}")

        if question.resource == "file":
            vectorstore = file_vectorstore
        elif question.resource == "web":
            vectorstore = website_vectorstore
        else:
            raise ValueError(f"Unsupported resource: {question.resource}")
        retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})

        contextualize_q_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", _CONTEXTUALIZE_Q_SYSTEM_PROMPT),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        history_aware_retriever = create_history_aware_retriever(chat_model, retriever, contextualize_q_prompt)

        qa_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", _QA_SYSTEM_PROMPT),
                MessagesPlaceholder("chat_history"),
                ("human", "{input}"),
            ]
        )
        question_answer_chain = create_stuff_documents_chain(chat_model, qa_prompt)
        rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

        conversational_rag_chain = RunnableWithMessageHistory(
            rag_chain,
            _get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer",
        )
        answer = conversational_rag_chain.invoke(
            {"input": question.question},
            config={"configurable": {"session_id": session_id}},
        )["answer"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return answer