| import os |
| import streamlit as st |
| from PIL import Image, ImageOps |
| from langchain_openai import ChatOpenAI |
| from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings |
| from langchain.vectorstores import FAISS |
| from langchain.chains import RetrievalQA |
| from langchain import PromptTemplate |
| from langchain.retrievers import ContextualCompressionRetriever |
| from langchain.retrievers.document_compressors import FlashrankRerank |
| from dotenv import load_dotenv |
| from langchain_community.embeddings.bedrock import BedrockEmbeddings |
# Pull API keys (OPENAI_API_KEY, DeepSeek_API_KEY) into the environment
# from a local .env file before any client is constructed.
load_dotenv()
| |
# --- Retrieval / chunking configuration -------------------------------------
# Chunking parameters — presumably match how the FAISS index was built;
# they are not referenced elsewhere in this file (TODO confirm with indexer).
PDF_CHUNK_SIZE = 1024
PDF_CHUNK_OVERLAP = 256
# Number of candidate documents fetched from FAISS before Flashrank reranking.
k = 9
|
|
| |
def load_and_pad_image(image_path, size=(64, 64)):
    """Open the image at *image_path* and pad it to *size* (default 64x64).

    Fix: ``Image.open`` returns a lazy file-backed object; the original code
    never closed it, leaking the file handle until garbage collection. The
    context manager closes it deterministically — ``ImageOps.pad`` forces a
    full pixel load, so the returned image is valid after the file closes.
    """
    with Image.open(image_path) as img:
        return ImageOps.pad(img, size)
|
|
# Favicon asset, padded to a square; reused below as the page header icon.
favicon_path = "medical.png"
favicon_image = load_and_pad_image(favicon_path)
|
|
| |
# Page chrome — must be the first Streamlit command executed on the page.
st.set_page_config(page_title="Chatbot", page_icon=favicon_image)
|
|
| |
# Page header: icon in a narrow column, title alongside it.
col1, col2 = st.columns([1, 8])
with col1:
    st.image(favicon_image)
with col2:
    # Negative top margin pulls the title level with the icon.
    st.markdown(
        """
        <h1 style='text-align: left; margin-top: -12px;'>Chatbot</h1>
        """, unsafe_allow_html=True
    )
|
|
| |
# User-facing model choices; the selections drive get_llm / load_vector_store.
model_options = ["gpt-4o", "gpt-4o-mini"]
embedding_model_options = ["OpenAI"]

# Widget order matters for layout: chat model picker first, then embeddings.
selected_model = st.selectbox("Choose a GPT model", model_options)
selected_embedding_model = st.selectbox("Choose an Embedding model", embedding_model_options)
|
|
| |
def get_llm(selected_model):
    """Build a ChatOpenAI client for *selected_model*.

    ``"deepseek-chat"`` is routed to DeepSeek's OpenAI-compatible endpoint
    with its own API key; every other model uses the standard OpenAI
    endpoint. Fix: the original read DeepSeek_API_KEY for deepseek-chat but
    never set ``base_url``, so that branch could never reach DeepSeek.
    """
    is_deepseek = selected_model == "deepseek-chat"
    api_key = os.getenv("DeepSeek_API_KEY") if is_deepseek else os.getenv("OPENAI_API_KEY")
    return ChatOpenAI(
        model=selected_model,
        temperature=0,    # deterministic answers for QA
        max_tokens=None,  # defer to the API's default completion cap
        api_key=api_key,
        # None preserves the ChatOpenAI default (api.openai.com) for OpenAI models.
        base_url="https://api.deepseek.com" if is_deepseek else None,
    )
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
@st.cache_resource
def load_vector_store(selected_embedding_model):
    """Load the FAISS index that matches the chosen embedding backend.

    Cached by Streamlit so the index is deserialized once per selection
    rather than on every script rerun.
    """
    if selected_embedding_model != "OpenAI":
        # MedEmbed (HuggingFace) index.
        hf_embeddings = HuggingFaceEmbeddings(model_name="abhinand/MedEmbed-large-v0.1")
        return FAISS.load_local(
            "faiss_index_medical_MedEmbed",
            hf_embeddings,
            allow_dangerous_deserialization=True,
        )
    # OpenAI text-embedding-3-large index.
    openai_embeddings = OpenAIEmbeddings(
        model="text-embedding-3-large",
        api_key=os.getenv("OPENAI_API_KEY"),
    )
    # allow_dangerous_deserialization: the pickled index is our own artifact.
    return FAISS.load_local("faiss_table", openai_embeddings, allow_dangerous_deserialization=True)
| |
| |
# Resolve the user's current selections into live resources.
llm = get_llm(selected_model)
vector_store = load_vector_store(selected_embedding_model)
|
|
| |
def main():
    """Render the Q&A page.

    Flow: question input -> FAISS retrieval (top ``k``) -> Flashrank
    rerank/compression -> RetrievalQA over the selected LLM -> answer
    written to the page.
    """
    # Stash the index in session state so reruns (and other code reading
    # session state) share the same knowledge base object.
    st.session_state['knowledge_base'] = vector_store
    st.header("Ask a Question")

    question = st.text_input("Enter your question")
    if st.button("Get Answer"):
        # Fix: the original fired a full retrieval + LLM round-trip even
        # when the input box was empty. Guard first.
        if not question.strip():
            st.warning("Please enter a question first.")
            return

        knowledge_base = st.session_state['knowledge_base']
        # Fetch k candidates from FAISS, then let Flashrank rerank/compress.
        retriever = knowledge_base.as_retriever(search_kwargs={"k": k})
        compressor = FlashrankRerank()
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=retriever
        )

        system_prompt = """
        You are a friendly and knowledgeable assistant who is an expert in medical education who will only answer from the context provided. You need to understand the best context to answer the question.
        """

        # The f-string interpolates system_prompt now; the doubled braces
        # survive as literal {context}/{question} placeholders for
        # PromptTemplate to fill at query time.
        template = f"""
        {system_prompt}
        -------------------------------
        Context: {{context}}
        Question: {{question}}
        Answer:
        """

        prompt = PromptTemplate(
            template=template,
            input_variables=['context', 'question']
        )

        qa_chain = RetrievalQA.from_chain_type(
            llm,
            retriever=compression_retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt}
        )

        # Spinner gives feedback during the retrieval + generation round-trip.
        with st.spinner("Thinking..."):
            response = qa_chain.invoke({"query": question})
        st.write(f"**Answer:** {response['result']}")


if __name__ == "__main__":
    main()
|
|