| import streamlit as st |
| import PyPDF2 |
| import os |
| from sentence_transformers import SentenceTransformer |
| import faiss |
| import numpy as np |
| from transformers import pipeline |
|
|
# Path of the PDF that serves as the knowledge base for the app.
PDF_FILE = "ml_large_dataset.pdf"

# Custom CSS injected into the page: background colour, heading colour,
# text-input border and button colours.
_PAGE_CSS = """
<style>
.main {background-color: #f7faff;}
h1 {color: #4a4a8a;}
.stTextInput>div>div>input {border: 2px solid #d0d7ff;}
.stButton button {background-color: #4a4a8a; color: white;}
</style>
"""

# NOTE(review): the leading "π" in the page title and heading below looks like
# a mis-encoded emoji - confirm the intended character against the original file.
st.set_page_config(
    page_title="π PDF RAG QA",
    layout="wide",
)

# unsafe_allow_html is needed so the <style> element is emitted as raw HTML.
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# Page heading and a one-line description of the approach.
st.title("π Ask Me Anything About Machine Learning")
st.caption("Using RAG (Retrieval-Augmented Generation) and a preloaded PDF")
|
|
def load_pdf(file_path):
    """Read a PDF and return the extracted text of each page.

    Args:
        file_path: Path of the PDF file; it is opened in binary mode.

    Returns:
        A list with one entry per page, in page order, holding whatever
        ``page.extract_text()`` returned for that page.
    """
    page_texts = []
    with open(file_path, "rb") as pdf_stream:
        # Extraction happens while the stream is still open, since the reader
        # is constructed on top of it.
        for page in PyPDF2.PdfReader(pdf_stream).pages:
            page_texts.append(page.extract_text())
    return page_texts
|
|
def chunk_text(pages, max_len=1000):
    """Split the text of *pages* into chunks of at most *max_len* words.

    The pages are joined with single spaces, split on whitespace, and the
    resulting words are regrouped into consecutive, non-overlapping chunks.

    Args:
        pages: Iterable of page texts. Entries that are ``None`` are skipped
            instead of breaking the join.
        max_len: Maximum number of words per chunk; must be at least 1.

    Returns:
        A list of strings, each holding up to *max_len* space-separated
        words. The list is empty when the pages contain no words.

    Raises:
        ValueError: If *max_len* is smaller than 1.
    """
    # A zero step makes range() fail with an unrelated-looking message and a
    # negative step silently yields no chunks at all, so reject both up front.
    if max_len < 1:
        raise ValueError(f"max_len must be at least 1, got {max_len!r}")

    text = " ".join(page for page in pages if page is not None)
    words = text.split()
    return [
        " ".join(words[start:start + max_len])
        for start in range(0, len(words), max_len)
    ]
|
|
@st.cache_resource
def setup_rag():
    """Build everything needed to answer questions about the PDF.

    Loads ``PDF_FILE``, splits it into chunks, embeds the chunks with a
    SentenceTransformer model, stores the embeddings in a flat L2 FAISS
    index, and creates a question-answering pipeline. The function is wrapped
    in ``st.cache_resource``.

    Returns:
        A tuple ``(chunks, model, index, qa)``: the list of text chunks, the
        embedding model, the FAISS index over the chunk embeddings, and the
        question-answering pipeline.

    Raises:
        ValueError: If no text could be extracted from the PDF, since an
            index cannot be built from zero chunks.
    """
    pages = load_pdf(PDF_FILE)
    chunks = chunk_text(pages)
    # Stop here with a clear message rather than failing later while reading
    # the embedding dimension of an empty result.
    if not chunks:
        raise ValueError(
            f"No extractable text found in {PDF_FILE!r}; cannot build the search index."
        )

    model = SentenceTransformer('all-MiniLM-L6-v2')
    # NOTE(review): chunk_text defaults to 1000-word chunks; the embedding
    # model presumably truncates long inputs, so only the start of each chunk
    # may be represented - confirm against the model's max sequence length.
    # FAISS expects a contiguous float32 matrix; this is a no-op when the
    # encoder already returns one.
    embeddings = np.ascontiguousarray(model.encode(chunks), dtype=np.float32)

    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)

    qa = pipeline("question-answering", model="deepset/roberta-base-squad2")
    return chunks, model, index, qa
|
|
def retrieve_answer(question, chunks, model, index, qa_pipeline, k=6):
    """Answer *question* using the chunks closest to it in the index.

    The question is embedded with *model*, the nearest chunk ids are looked
    up in *index*, the matching chunks are joined with blank lines into one
    context, and *qa_pipeline* is asked to extract the answer from it.

    Args:
        question: The question text.
        chunks: List of text chunks; positions are expected to match the ids
            stored in *index*.
        model: Object whose ``encode(list_of_str)`` returns the query vectors.
        index: Object whose ``search(vectors, k)`` returns ``(distances, ids)``.
        qa_pipeline: Callable taking ``question=`` and ``context=`` and
            returning a mapping with an ``'answer'`` key.
        k: Maximum number of chunks to retrieve.

    Returns:
        The ``'answer'`` value produced by *qa_pipeline*, or an empty string
        when no chunk could be retrieved.
    """
    # Never request more neighbours than there are chunks to return.
    k = min(k, len(chunks))
    if k <= 0:
        return ""

    q_embed = np.asarray(model.encode([question]))
    _, neighbour_ids = index.search(q_embed, k)

    # FAISS pads missing results with -1. Used as a list index, -1 would
    # silently select the last chunk, so drop anything outside the valid range.
    hits = [
        chunks[int(i)]
        for i in neighbour_ids[0]
        if 0 <= int(i) < len(chunks)
    ]
    if not hits:
        return ""

    context = "\n\n".join(hits)
    result = qa_pipeline(question=question, context=context)
    return result['answer']
|
|
# Show a readable message instead of a raw traceback when the source PDF is
# missing, and halt this script run.
if not os.path.isfile(PDF_FILE):
    st.error(f"PDF file not found: {PDF_FILE}")
    st.stop()

# Build the chunks, embedding model, FAISS index and QA pipeline.
chunks, embed_model, faiss_index, qa_model = setup_rag()

st.subheader("π¬ Ask a Question")
question = st.text_input("Enter your question:", placeholder="e.g., What is supervised learning?")

# Whitespace-only input is truthy but carries no question, so strip it before
# deciding whether to run retrieval.
question = (question or "").strip()

if question:
    with st.spinner("π§ Searching for the answer..."):
        answer = retrieve_answer(question, chunks, embed_model, faiss_index, qa_model)
    st.markdown("#### π Answer:")
    st.write(answer)
|
|