import os
import pickle

import faiss
import streamlit as st
from datasets import load_dataset
from groq import Groq
from groq import APIError
from sentence_transformers import SentenceTransformer
|
|
| |
DATASET_NAME = "neural-bridge/rag-dataset-1200"  # Hugging Face dataset providing the contexts
MODEL_NAME = "all-MiniLM-L6-v2"  # sentence-transformers model used for embeddings
INDEX_FILE = "faiss_index.pkl"  # pickled FAISS index
DOCS_FILE = "contexts.pkl"  # pickled list of context strings, row-aligned with the index

# Configure the page before any other Streamlit output.
st.set_page_config(page_title="RAG App", layout="wide")
st.title("🧠 Retrieval-Augmented Generation (RAG) with Groq")

# The key is read from MY_KEY. GROQ_API_KEY (the SDK's own default variable)
# remains a fallback. A missing key is reported in the page instead of
# surfacing as a traceback from the Groq constructor.
_groq_api_key = os.environ.get("MY_KEY") or os.environ.get("GROQ_API_KEY")
if not _groq_api_key:
    st.error(
        "Groq API key not found. Set the MY_KEY (or GROQ_API_KEY) "
        "environment variable and restart the app."
    )
    st.stop()
client = Groq(api_key=_groq_api_key)
|
|
| |
@st.cache_resource
def setup_database():
    """Build the FAISS index over the dataset contexts and save it to disk.

    Loads the ``train`` split of DATASET_NAME and embeds every ``context``
    field with MODEL_NAME. The vectors go into an exact (brute-force) L2
    index. The index and the context list are then pickled to INDEX_FILE and
    DOCS_FILE so later runs can skip this work. Progress is shown in the page.

    Returns:
        tuple: ``(faiss_index, contexts)``, where ``contexts[i]`` is the text
        whose embedding was added as vector ``i`` of the index.
    """
    st.info("Setting up vector database...")
    progress = st.progress(0)

    dataset = load_dataset(DATASET_NAME, split="train")
    contexts = [entry["context"] for entry in dataset]
    progress.progress(25)

    embedder = SentenceTransformer(MODEL_NAME)
    embeddings = embedder.encode(contexts, show_progress_bar=True)
    progress.progress(50)

    # encode() on a list returns one row per text; the column count is the
    # embedding dimension.
    faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
    faiss_index.add(embeddings)
    progress.progress(75)

    with open(INDEX_FILE, "wb") as f:
        pickle.dump(faiss_index, f)
    with open(DOCS_FILE, "wb") as f:
        pickle.dump(contexts, f)

    progress.progress(100)
    st.success("Database setup complete!")
    return faiss_index, contexts


@st.cache_resource
def _load_database():
    """Return ``(faiss_index, contexts)``, reusing the saved copy when possible.

    Streamlit reruns the whole script on every interaction. Caching this
    function means the pickles are read once per server process, not once per
    rerun. Falls back to setup_database() in three cases:
      - either file is missing;
      - either file cannot be read or unpickled;
      - the index and the context list disagree in length.
    """
    if os.path.exists(INDEX_FILE) and os.path.exists(DOCS_FILE):
        try:
            # NOTE(review): pickle.load can execute arbitrary code. This is
            # acceptable only because these files are produced by
            # setup_database() and never come from users.
            with open(INDEX_FILE, "rb") as f:
                index = pickle.load(f)
            with open(DOCS_FILE, "rb") as f:
                contexts = pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError) as err:
            st.warning(f"Could not read saved database ({err}); rebuilding it.")
        else:
            # Retrieval maps index row i to contexts[i], so the two must match.
            if index.ntotal == len(contexts):
                st.info("Loaded existing database.")
                return index, contexts
            st.warning("Saved index and contexts are out of sync; rebuilding.")
    return setup_database()


faiss_index, all_contexts = _load_database()
|
|
| |
# Example prompts about the dataset; the first one pre-fills the input box.
sample_questions = [
    "What is the purpose of the RAG dataset?",
    "How does Falcon RefinedWeb contribute to this dataset?",
    "What are the benefits of using retrieval-augmented generation?",
    "Explain the structure of the RAG-1200 dataset.",
]
default_question = sample_questions[0]

st.subheader("Ask a question based on the dataset:")
question = st.text_input(label="Enter your question:", value=default_question)
|
|
@st.cache_resource
def _load_query_embedder(model_name):
    """Return the SentenceTransformer for *model_name*, loaded once per process.

    Streamlit reruns the script on every interaction, so constructing the
    model inline would reload its weights on every "Ask" click.
    """
    return SentenceTransformer(model_name)


if st.button("Ask"):
    if not question.strip():
        st.warning("Please enter a question.")
    else:
        answer = None
        with st.spinner("Retrieving and generating answer..."):
            embedder = _load_query_embedder(MODEL_NAME)
            query_embedding = embedder.encode([question])
            # Retrieve only the single nearest context (k=1); distances are unused.
            _, neighbor_ids = faiss_index.search(query_embedding, k=1)
            context = all_contexts[neighbor_ids[0][0]]
            prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"

            try:
                # NOTE(review): Groq retires model IDs over time; confirm
                # "llama3-70b-8192" is still served.
                response = client.chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    model="llama3-70b-8192",
                )
            except APIError as err:
                # APIError is the SDK's base class for connection and HTTP-status
                # errors. Report it in the page rather than aborting the run.
                st.error(f"Groq request failed: {err}")
            else:
                answer = response.choices[0].message.content

        if answer is not None:
            st.success("Answer:")
            st.markdown(answer)

        # Show the retrieved context even if generation failed, to aid debugging.
        with st.expander("🔍 Retrieved Context"):
            st.markdown(context)
|
|