| import pinecone |
| import requests |
| import streamlit as st |
| import torch |
|
|
| from transformers import AutoTokenizer, AutoModel |
|
|
| from config import config |
| |
| |
def search(text: str, k: int = 5, timeout: float = 30.0):
    """Get the k closest articles to the text.

    The text is embedded with ``_get_embeddings`` and the resulting vector is
    POSTed to the ``/query`` REST endpoint of the configured Pinecone index.

    Args:
        text: Free-text query to embed and search for.
        k: Number of nearest matches to request (sent as ``top_k``).
        timeout: Seconds to wait for Pinecone to respond. Without a timeout
            ``requests`` waits forever, which would hang the app if the
            endpoint stalls.

    Returns:
        The decoded JSON body of the Pinecone response.

    Raises:
        Exception: If the response status code is not 200; the message
            carries the status code and the response text.
        requests.exceptions.RequestException: If the request cannot be
            completed (connection failure, or no response within ``timeout``).
    """
    embeds = _get_embeddings(text)

    # NOTE(review): the "-5b18b87" host suffix is hard-coded; presumably it is
    # the Pinecone project identifier -- confirm before reusing elsewhere.
    r = requests.post(
        f"https://{config.pinecone_index}-5b18b87.svc.{config.pinecone_env}.pinecone.io/query",
        headers={
            "Api-Key": config.pinecone_api_key,
            "accept": "application/json",
            "content-type": "application/json",
        },
        json={
            "vector": embeds,
            "top_k": k,
            "includeMetadata": True,
            "includeValues": False,
        },
        timeout=timeout,
    )

    if r.status_code == 200:
        return r.json()
    raise Exception(f"Error: {r.status_code} - {r.text}")
|
|
| |
def _get_embeddings(text: str):
    """Embed ``text`` and return the pooled vector as plain Python floats.

    The tokenizer and model are read from ``st.session_state``. The first
    element of the model output is averaged over dimension 1, squeezed, and
    converted to a list.
    """
    tokenize = st.session_state.tokenizer
    encoded = tokenize(text, return_tensors="pt", padding=True, truncation=True)

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        outputs = st.session_state.model(**encoded)

    # NOTE(review): outputs[0] is presumably the last hidden state with shape
    # (batch, seq_len, hidden) -- confirm for the configured model.
    token_states = outputs[0]
    pooled = token_states.mean(dim=1)
    return pooled.squeeze().tolist()
|
|
|
|
|
|
# --- Streamlit page -------------------------------------------------------
# The whole UI is gated behind a shared password taken from config.
password = st.text_input("Password", type="password")
if password == config.password:
    st.title("PubMed Embeddings")
    st.subheader("Search for a PubMed article and get its id.")

    text = st.text_input("Search for a PubMed article", "Epidemiology of COVID-19")

    with st.spinner("Loading Embedding Model..."):
        # The keyword is `environment`; `env` is not a pinecone.init parameter,
        # so the configured environment was not being applied.
        pinecone.init(api_key=config.pinecone_api_key, environment=config.pinecone_env)
        # Streamlit re-runs this script on every interaction; keeping the
        # index handle, tokenizer and model in session_state means they are
        # created once per session rather than once per rerun.
        if "index" not in st.session_state:
            st.session_state.index = pinecone.Index(config.pinecone_index)
        if "tokenizer" not in st.session_state:
            st.session_state.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        if "model" not in st.session_state:
            st.session_state.model = AutoModel.from_pretrained(config.model_name)

    if st.button("Search"):
        with st.spinner("Searching..."):
            results = search(text)

        # One line per match: article id and similarity score.
        for res in results["matches"]:
            st.write(f"{res['id']} - confidence: {res['score']:.2f}")
elif password:
    # Only complain once something has been typed; an empty box on first
    # load is not a failed attempt.
    st.write("Password incorrect!")
| |
|
|