| import faiss |
| import pickle |
| import datasets |
| import numpy as np |
| import requests |
| import streamlit as st |
| from vector_engine.utils import vector_search |
| from transformers import AutoModel, AutoTokenizer |
|
|
| from datasets import load_dataset |
|
|
| |
| @st.cache_data |
| def read_data(dataset_repo='dhmeltzer/ELI5_embedded'): |
| """Read the data from huggingface.""" |
| return load_dataset(dataset_repo)['train'] |
|
|
| @st.cache_data |
| def load_faiss_index(path_to_faiss="./faiss_index.pickle"): |
| """Load and deserialize the Faiss index.""" |
| with open(path_to_faiss, "rb") as h: |
| data = pickle.load(h) |
| return faiss.deserialize_index(data) |
|
|
| def main(): |
| |
| data = read_data() |
| faiss_index = load_faiss_index() |
|
|
| model_id="sentence-transformers/all-MiniLM-L6-v2" |
| |
| api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}" |
| headers = {"Authorization": f"Bearer {st.secrets['HF_token']}"} |
|
|
| def query(texts): |
| response = requests.post(api_url, headers=headers, json={"inputs": texts, "options":{"wait_for_model":True}}) |
| return response.json() |
|
|
|
|
| st.title("Semantic Search for Questions on Reddit.") |
|
|
| st.write("""This application lets you perform a semantic search through questions in the r/ELI5 [dataset](https://huggingface.co/datasets/eli5). \ |
| The questions and user input are encoded into a high-dimensional vectors space using a Sentence-Transformer model, and in particular the checkpoint [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2"). |
| To perform the search we use FAISS, which performs an efficient similarity search through the (vectorized) questions. |
| The ELI5 dataset contains posts from three subreddits, AskScience (asks), AskHistorians (askh), and ExplainLikeImFive (eli5). |
| The score corresponds to the rating each answer recieved when posted on Reddit. |
| \n You can use the slider on the left to change the number of results shown. |
| We unfortunately cannot verify the veracity of any of the answers posted!""") |
|
|
| |
| |
| user_input = st.text_area("Search box", "What is spacetime made out of?") |
|
|
| |
| st.sidebar.markdown("**Filters**") |
| num_results = st.sidebar.slider("Number of search results", 1, 50, 5) |
|
|
| vector = query([user_input]) |
| |
| if user_input: |
| |
| _, I = faiss_index.search(np.array(vector).astype("float32"), k=num_results) |
| |
| |
| for id_ in I.flatten().tolist(): |
| row = data[id_] |
| |
| answers=row['answers']['text'] |
| answers_URLs = row['answers_urls']['url'] |
| for k in range(len(answers_URLs)): |
| answers = [answer.replace(f'_URL_{k}_',answers_URLs[k]) for answer in answers] |
| |
| |
| st.write(f"**Title**: {row['title']}") |
| st.write(f"**Split**: {row['split']}") |
| st.write(f"**Score**: {row['answers']['score'][0]}") |
| st.write(f"**Top Answer**: {answers[0]}") |
| st.write("-"*20) |
|
|
|
|
|
|
| if __name__ == "__main__": |
| main() |