| import re |
| import gradio as gr |
| from scipy.sparse import load_npz |
| import torch |
| from sklearn.metrics.pairwise import cosine_similarity |
| from sklearn.preprocessing import normalize |
| from transformers import BertTokenizer, BertModel |
| import numpy as np |
| import pandas as pd |
| from datasets import load_dataset |
| from gensim.models import KeyedVectors |
| import plotly.graph_objects as go |
| from sklearn.decomposition import PCA |
| from transformers import AutoTokenizer, AutoModel |
| from sentence_transformers import CrossEncoder |
| from sentence_transformers import SentenceTransformer |
|
|
class ArxivSearch:
    """Gradio app that searches arXiv papers with several embedding models.

    The constructor builds the UI, parses the dataset, loads every supported
    model's artifacts from disk and then calls ``self.iface.launch()``.

    Supported ``embedding`` values: "tfidf", "word2vec", "bert", "sbert" and
    "clustered sbert".
    """

    # Matches "abstract" or "introduction" case-insensitively, allowing
    # whitespace between the letters (e.g. "A B S T R A C T").
    _SECTION_HEADING = re.compile(
        r'a\s*b\s*s\s*t\s*r\s*a\s*c\s*t|i\s*n\s*t\s*r\s*o\s*d\s*u\s*c\s*t\s*i\s*o\s*n',
        re.IGNORECASE,
    )
    # arXiv identifier such as "arXiv:1234.56789v12"; the version suffix is optional.
    _ARXIV_ID = re.compile(r'arxiv:\d{4}\.\d{4,5}(?:v\d+)?', re.IGNORECASE)
    # Number of leading documents drawn as the background cloud of the 3D plot.
    _PLOT_SAMPLE = 5000
    # Number of bi-encoder candidates re-ranked by the cross-encoder.
    _RERANK_CANDIDATES = 50
    # Instance attributes populated by load_model() for each embedding type.
    _MODEL_ATTRS = {
        "tfidf": ("tfidf_matrix", "feature_names"),
        "word2vec": ("word2vec_embeddings", "wv_model"),
        "bert": ("bert_embeddings", "tokenizer", "model"),
        "sbert": ("sbert_model", "sbert_embedding", "cross_encoder"),
        "clustered sbert": (
            "sbert_model", "sbert_embedding", "cross_encoder",
            "clusters", "cluster_centers", "clustered_embeddings",
        ),
    }

    def __init__(self, dataset, embedding="sbert"):
        """Build the UI, load data and all models, then launch the app.

        Args:
            dataset: Mapping with a ``"train"`` split whose rows have a
                ``"text"`` field (e.g. a ``datasets.DatasetDict``).
            embedding: Model selected initially; one of the dropdown choices.
        """
        self.dataset = dataset
        self.embedding = embedding
        self.query = None
        self.documents = []
        self.titles = []
        self.raw_texts = []
        self.arxiv_ids = []
        self.last_results = []
        self.query_encoding = None

        self.embedding_dropdown = gr.Dropdown(
            choices=["tfidf", "word2vec", "bert", "sbert", "clustered sbert"],
            value=embedding,
            label="Model",
        )
        self.plot_button = gr.Button("Show 3D Plot")

        with gr.Blocks() as self.iface:
            gr.Markdown("# arXiv Search Engine")
            gr.Markdown("Search arXiv papers by keyword and embedding model.")

            self.plot_output = gr.Plot()

            with gr.Row():
                self.query_box = gr.Textbox(lines=1, placeholder="Enter your search query", label="Query")
                self.embedding_dropdown.render()
                self.plot_button.render()
                with gr.Column():
                    self.search_button = gr.Button("Search")

            self.output_md = gr.Markdown()

            # Submitting the query, changing the model and clicking "Search"
            # all run the same search and render into the same Markdown box.
            search_inputs = [self.query_box, self.embedding_dropdown]
            self.query_box.submit(self.search_function, inputs=search_inputs, outputs=self.output_md)
            self.embedding_dropdown.change(self.search_function, inputs=search_inputs, outputs=self.output_md)
            self.search_button.click(self.search_function, inputs=search_inputs, outputs=self.output_md)
            self.plot_button.click(self.plot_3d_embeddings, inputs=[], outputs=self.plot_output)

        self.load_data(dataset)

        # search_function() only switches self.embedding and never loads a
        # model, so every model is loaded here. "clustered sbert" also loads
        # the SBERT model, embeddings and cross-encoder used by "sbert".
        for name in ("tfidf", "word2vec", "bert", "clustered sbert"):
            self.load_model(name)
        # load_model() overwrites self.embedding; restore the selected model so
        # state matches the dropdown (the 3D plot reads it before any search).
        self.embedding = embedding

        self.iface.launch()

    def load_data(self, dataset):
        """Parse the ``"train"`` split into documents, titles and arXiv ids.

        Rows whose text is empty or shorter than 10 characters after stripping
        are skipped. NOTE(review): the precomputed embeddings are indexed by
        position; if any row is skipped here, self.documents/self.arxiv_ids
        no longer line up with dataset (and embedding) row numbers — confirm
        the embeddings were built with the same filter.
        """
        for item in dataset["train"]:
            text = item["text"]
            if not text or len(text.strip()) < 10:
                continue
            title, arxiv_id = self._parse_header(text)
            stripped = text.strip()
            self.raw_texts.append(stripped)
            self.titles.append(title)
            self.documents.append(stripped)
            self.arxiv_ids.append(arxiv_id)

    @classmethod
    def _parse_header(cls, text):
        """Return ``(title, arxiv_id)`` parsed from a document's leading lines.

        The title is every stripped line before the first line starting with
        "arxiv:" (or all lines if there is none). ``arxiv_id`` is the
        lower-cased identifier from that line, or None.
        """
        title_lines = []
        for line in text.splitlines():
            stripped = line.strip()
            if stripped.lower().startswith("arxiv:"):
                match = cls._ARXIV_ID.search(stripped)
                arxiv_id = match.group(0).lower() if match else None
                return " ".join(title_lines).strip(), arxiv_id
            title_lines.append(stripped)
        return " ".join(title_lines).strip(), None

    def plot_dense(self, embedding, pca, results_indices):
        """Project a dense embedding matrix, the results and the query to 3D.

        ``pca`` is fitted on the first ``_PLOT_SAMPLE`` rows plus the result
        rows, then used to transform each group.

        Returns:
            ``(reduced_data, reduced_results_points, query_point)``; each is
            an ``(k, 3)`` array, empty when there is nothing to project.
        """
        n_sample = min(self._PLOT_SAMPLE, embedding.shape[0])
        fit_indices = sorted(set(results_indices) | set(range(n_sample)))
        pca.fit(embedding[fit_indices])
        reduced_data = pca.transform(embedding[:n_sample])
        if len(results_indices) > 0:
            reduced_results_points = pca.transform(embedding[results_indices])
        else:
            reduced_results_points = np.empty((0, 3))
        if self.query_encoding is not None and self.query_encoding.shape[0] > 0:
            query_point = pca.transform(self.query_encoding)
        else:
            query_point = np.empty((0, 3))
        return reduced_data, reduced_results_points, query_point

    def _cluster_colors(self, n_points):
        """Colour the first ``n_points`` documents by cluster membership.

        Documents in the cluster chosen by the last clustered search are blue;
        all others (or all of them, before any clustered search) are white.
        """
        colors = np.full(n_points, "#ffffff", dtype=object)
        top_cluster = getattr(self, "top_cluster_index", None)
        if top_cluster is not None:
            labels = np.asarray(self.clusters)[:n_points]
            # colors[:len(labels)] is a view, so the masked write updates colors.
            colors[:len(labels)][labels == top_cluster] = "#00b7ff"
        return colors.tolist()

    def plot_3d_embeddings(self):
        """Return a plotly 3D PCA scatter of documents, last results and query.

        Uses the embedding named by ``self.embedding``. Results are coloured by
        score; the query point is drawn only for dense embeddings with a
        stored query vector. For "clustered sbert" the last search's top
        cluster is highlighted.

        Raises:
            ValueError: if ``self.embedding`` is not a supported model.
        """
        pca = PCA(n_components=3)
        results_indices = [idx for idx, _ in self.last_results]
        results_scores = [score for _, score in self.last_results]
        query_point = np.empty((0, 3))
        cluster_colors = None

        if self.embedding == "tfidf":
            # The TF-IDF matrix is sparse: densify only the rows PCA needs.
            n_sample = min(self._PLOT_SAMPLE, self.tfidf_matrix.shape[0])
            fit_indices = sorted(set(results_indices) | set(range(n_sample)))
            pca.fit(self.tfidf_matrix[fit_indices].toarray())
            reduced_data = pca.transform(self.tfidf_matrix[:n_sample].toarray())
            if results_indices:
                reduced_results_points = pca.transform(self.tfidf_matrix[results_indices].toarray())
            else:
                reduced_results_points = np.empty((0, 3))
        elif self.embedding == "word2vec":
            reduced_data, reduced_results_points, query_point = self.plot_dense(
                self.word2vec_embeddings, pca, results_indices)
        elif self.embedding == "bert":
            reduced_data, reduced_results_points, query_point = self.plot_dense(
                self.bert_embeddings, pca, results_indices)
        elif self.embedding in ("sbert", "clustered sbert"):
            reduced_data, reduced_results_points, query_point = self.plot_dense(
                self.sbert_embedding, pca, results_indices)
            if self.embedding == "clustered sbert":
                cluster_colors = self._cluster_colors(len(reduced_data))
        else:
            raise ValueError(f"Unsupported embedding type: {self.embedding}")

        # Hover labels only for plotted points that have a parsed document.
        n_labels = min(len(reduced_data), len(self.documents))
        doc_labels = [
            f"<br>: {self.arxiv_ids[i] if self.arxiv_ids[i] else ' '.join(self.documents[i].split()[:10])}"
            for i in range(n_labels)
        ]
        traces = [
            go.Scatter3d(
                x=reduced_data[:, 0],
                y=reduced_data[:, 1],
                z=reduced_data[:, 2],
                mode='markers',
                marker=dict(
                    size=3.5,
                    color=cluster_colors if cluster_colors is not None else "#ffffff",
                    opacity=0.2,
                ),
                name='All Documents',
                text=doc_labels,
                hoverinfo='text',
            )
        ]

        if len(reduced_results_points) > 0:
            traces.append(go.Scatter3d(
                x=reduced_results_points[:, 0],
                y=reduced_results_points[:, 1],
                z=reduced_results_points[:, 2],
                mode='markers',
                marker=dict(
                    size=4.25,
                    color=results_scores,
                    colorscale=[[0.0, "#00ffea"], [1.0, "#ffea00"]],
                    opacity=0.99,
                    colorbar=dict(
                        title="Score",
                        bgcolor='rgba(0,0,0,0)',
                        bordercolor='rgba(0,0,0,0)',
                    ),
                ),
                name='Results',
                text=[f"<br>{self.documents[i][:100]}" for i in results_indices],
                hoverinfo='text',
            ))

        if len(query_point) > 0:
            traces.append(go.Scatter3d(
                x=query_point[:, 0],
                y=query_point[:, 1],
                z=query_point[:, 2],
                mode='markers',
                marker=dict(size=5, color='red', opacity=0.8),
                name='Query',
                text=[f"<br>Query: {self.query}"],
                hoverinfo='text',
            ))

        axis_style = dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray')
        layout = go.Layout(
            margin=dict(l=0, r=0, b=0, t=0),
            scene=dict(
                xaxis_title='PCA 1',
                yaxis_title='PCA 2',
                zaxis_title='PCA 3',
                xaxis=dict(axis_style),
                yaxis=dict(axis_style),
                zaxis=dict(axis_style),
            ),
            paper_bgcolor='black',
            plot_bgcolor='black',
            font=dict(color='white'),
            legend=dict(
                bgcolor='rgba(0,0,0,0)',
                bordercolor='rgba(0,0,0,0)',
                x=0.01,
                y=0.99,
                xanchor='left',
                yanchor='top',
            ),
        )
        return go.Figure(data=traces, layout=layout)

    def keyword_match_ranking(self, query, top_n=10):
        """Rank documents by the summed TF-IDF weight of the query's terms.

        Args:
            query: Free text; lower-cased and split on whitespace.
            top_n: Maximum number of results.

        Returns:
            Up to ``top_n`` ``(doc_index, score)`` pairs with score > 0, best
            first; ties keep ascending document order. Empty if no query term
            is in the vocabulary.
        """
        query_terms = set(query.lower().split())
        query_indices = [i for i, term in enumerate(self.feature_names) if term in query_terms]
        if not query_indices:
            return []
        # One sparse column slice + row sum instead of a Python loop over rows.
        doc_scores = np.asarray(self.tfidf_matrix[:, query_indices].sum(axis=1)).ravel()
        matching = np.flatnonzero(doc_scores > 0)
        order = np.argsort(-doc_scores[matching], kind="stable")
        return [(int(matching[k]), doc_scores[matching[k]]) for k in order[:top_n]]

    def word2vec_search(self, query, top_n=10):
        """Rank documents by cosine similarity to the mean word2vec query vector.

        Out-of-vocabulary tokens are ignored; returns [] if none remain.
        Stores the query vector in ``self.query_encoding`` for plotting.
        """
        tokens = [word for word in query.split() if word in self.wv_model.key_to_index]
        if not tokens:
            return []
        vectors = np.array([self.wv_model[word] for word in tokens])
        query_vec = np.mean(vectors, axis=0).reshape(1, -1)
        self.query_encoding = query_vec
        sims = cosine_similarity(query_vec, self.word2vec_embeddings).flatten()
        top_indices = sims.argsort()[::-1][:top_n]
        return [(i, sims[i]) for i in top_indices]

    def bert_search(self, query, top_n=10):
        """Rank documents by cosine similarity to the BERT [CLS] query vector.

        Stores the query vector in ``self.query_encoding`` for plotting.
        """
        with torch.no_grad():
            # NOTE(review): the query is repeated twice and padded to 512
            # tokens — presumably to mirror how bert_embeddings were built;
            # confirm against the embedding script.
            inputs = self.tokenizer((query + ' ') * 2, return_tensors="pt", truncation=True,
                                    max_length=512, padding='max_length')
            outputs = self.model(**inputs)
            query_vec = outputs.last_hidden_state[:, 0, :].numpy()
        self.query_encoding = query_vec
        sims = cosine_similarity(query_vec, self.bert_embeddings).flatten()
        top_indices = sims.argsort()[::-1][:top_n]
        return [(i, sims[i]) for i in top_indices]

    def _rerank(self, query, doc_indices, cos_scores, top_n):
        """Re-rank candidates with the cross-encoder blended with cosine scores.

        Args:
            query: The search query.
            doc_indices: Candidate row indices into ``self.dataset["train"]``.
            cos_scores: Bi-encoder cosine score for each candidate, same order.
            top_n: Maximum number of results.

        Returns:
            Up to ``top_n`` ``(doc_index, 0.7 * cross + 0.3 * cosine)`` pairs,
            best first.
        """
        train = self.dataset["train"]
        candidates = [train[int(i)]["text"] for i in doc_indices]
        cross_scores = self.cross_encoder.predict([(query, doc) for doc in candidates])
        final_scores = 0.7 * cross_scores + 0.3 * cos_scores
        best = final_scores.argsort()[::-1][:top_n]
        return [(doc_indices[k], final_scores[k]) for k in best]

    def sbert_search(self, query, top_n=10):
        """SBERT retrieval over all documents, then cross-encoder re-ranking.

        Stores the query vector in ``self.query_encoding`` for plotting.
        """
        query_vec = self.sbert_model.encode([query])
        self.query_encoding = query_vec
        cos_scores = cosine_similarity(query_vec, self.sbert_embedding)[0]
        candidates = np.argsort(cos_scores)[-self._RERANK_CANDIDATES:][::-1]
        return self._rerank(query, candidates, cos_scores[candidates], top_n)

    def clustered_sbert_search(self, query, top_n=10):
        """SBERT retrieval restricted to the closest cluster, then re-ranking.

        Sets ``self.top_cluster_index`` (used by the 3D plot) and
        ``self.query_encoding``. Returns [] if the chosen cluster is empty.
        """
        query_vec = self.sbert_model.encode([query])
        self.query_encoding = query_vec
        centre_scores = cosine_similarity(query_vec, self.cluster_centers)[0]
        # NOTE(review): assumes row k of cluster_centers is the centre of the
        # cluster labelled k in self.clusters — confirm with the clustering script.
        self.top_cluster_index = np.argmax(centre_scores)
        # Select members by label value so indices agree with self.clusters.
        members = np.flatnonzero(self.clusters == self.top_cluster_index)
        if members.size == 0:
            return []
        cos_scores = cosine_similarity(query_vec, self.sbert_embedding[members])[0]
        best_local = np.argsort(cos_scores)[-self._RERANK_CANDIDATES:][::-1]
        return self._rerank(query, members[best_local], cos_scores[best_local], top_n)

    def model_switch(self, embedding, progress=gr.Progress()):
        """Load ``embedding``, release the previous model and rerun the query.

        Attributes shared by the old and new model (e.g. the SBERT model used
        by both "sbert" and "clustered sbert") are kept.

        Returns:
            ``gr.update()`` if the model is unchanged; otherwise the rendered
            results for the last query, or "" if there is none.
        """
        if self.embedding == embedding:
            return gr.update()
        old_embedding = self.embedding
        print(f"Switching model to {embedding}")
        self.load_model(embedding)
        print(f"Loaded {embedding} model")
        self.embedding = embedding
        stale = set(self._MODEL_ATTRS.get(old_embedding, ())) - set(self._MODEL_ATTRS.get(embedding, ()))
        for attr in stale:
            if hasattr(self, attr):
                delattr(self, attr)
        print("old embedding removed")
        if getattr(self, "query", None):
            return self.search_function(self.query, self.embedding)
        return ""

    @staticmethod
    def _load_npz_array(path, key):
        """Read array ``key`` from the ``.npz`` archive at ``path`` and close it."""
        with np.load(path) as archive:
            return archive[key]

    def load_model(self, embedding):
        """Load the on-disk artifacts for ``embedding`` and set ``self.embedding``.

        Raises:
            ValueError: if ``embedding`` is not a supported model.
        """
        self.embedding = embedding
        if embedding == "tfidf":
            self.tfidf_matrix = load_npz("TF-IDF embeddings/tfidf_matrix_train.npz")
            with open("TF-IDF embeddings/feature_names.txt", "r", encoding="utf-8") as f:
                self.feature_names = [line.strip() for line in f]
        elif embedding == "word2vec":
            self.word2vec_embeddings = self._load_npz_array(
                "Word2Vec embeddings/word2vec_embedding.npz", "word2vec_embedding")
            self.wv_model = KeyedVectors.load("models/word2vec-trimmed.model")
        elif embedding == "bert":
            self.bert_embeddings = self._load_npz_array("BERT embeddings/bert_embedding.npz", "bert_embedding")
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.model = BertModel.from_pretrained('bert-base-uncased')
            self.model.eval()
        elif embedding in ("sbert", "clustered sbert"):
            self.sbert_model = SentenceTransformer("all-MiniLM-L6-v2")
            self.sbert_embedding = self._load_npz_array("BERT embeddings/sbert_embedding.npz", "sbert_embedding")
            self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-TinyBERT-L-2-v2")
            if embedding == "clustered sbert":
                self.clusters = pd.read_csv('raf_clusters/cluster_labels_sbert.csv')['cluster_label'].values
                self.cluster_centers = pd.read_csv('BERT embeddings/sbert_cluster_centers.csv').values
                self.clustered_embeddings = [self.sbert_embedding[self.clusters == i] for i in np.unique(self.clusters)]
        else:
            raise ValueError(f"Unsupported embedding type: {self.embedding}")

    def snippet_before_abstract(self, text):
        """Return the text preceding the first "abstract"/"introduction" heading.

        If the heading starts at offset >= 1000, the first 100 characters are
        returned; if there is no heading, or nothing precedes it, the first
        300 characters. Results are stripped.
        """
        match = self._SECTION_HEADING.search(text)
        if match is None:
            return text[:300].strip()
        if match.start() >= 1000:
            return text[:100].strip()
        return text[:match.start()].strip() or text[:300].strip()

    def set_embedding(self, embedding):
        """Select the model used by subsequent searches and plots."""
        self.embedding = embedding

    @staticmethod
    def _decode_escapes(query):
        r"""Interpret backslash escapes such as ``\n`` in ``query``.

        Non-ASCII characters are preserved (latin-1 + backslashreplace
        round-trips them through the unicode_escape codec). A malformed escape,
        e.g. a trailing backslash, leaves the query unchanged.
        """
        try:
            return query.encode("latin-1", "backslashreplace").decode("unicode_escape")
        except UnicodeDecodeError:
            return query

    def _format_results(self, results):
        """Render ``(doc_index, score)`` pairs as Markdown with HTML snippets."""
        from html import escape

        sections = []
        for rank, (idx, _score) in enumerate(results, start=1):
            arxiv_id = self.arxiv_ids[idx]
            if not arxiv_id:
                # Escape so characters like "<" in the paper text are shown, not parsed.
                sections.append(f"### Document {rank}\n")
                sections.append(f"<pre>{escape(self.documents[idx][:200], quote=False)}</pre>\n\n")
            else:
                link = f"https://arxiv.org/abs/{arxiv_id.replace('arxiv:', '')}"
                snippet = escape(self.snippet_before_abstract(self.documents[idx]), quote=False).replace('\n', '<br>')
                sections.append(f"### Document {rank}\n")
                sections.append(f"[arXiv Link]({link})\n\n")
                sections.append(f"<pre>{snippet}</pre>\n\n---\n")
        return "".join(sections)

    def search_function(self, query, embedding, progress=gr.Progress()):
        """Run ``query`` with ``embedding`` and return the results as Markdown.

        Updates ``self.embedding``, ``self.query``, ``self.last_results`` and
        ``self.query_encoding`` (the latter two feed the 3D plot).

        Returns:
            Markdown text, or "No results found." for an empty query or when
            the model returns nothing.
        """
        self.set_embedding(embedding)
        self.query = query
        # Clear the previous query vector: TF-IDF and an out-of-vocabulary
        # word2vec query set none, and a vector from another model would have
        # the wrong dimensionality for the plot's PCA.
        self.query_encoding = None
        query = self._decode_escapes(query or "")
        if not query.strip():
            self.last_results = []
            return "No results found."

        search_methods = {
            "tfidf": self.keyword_match_ranking,
            "word2vec": self.word2vec_search,
            "bert": self.bert_search,
            "sbert": self.sbert_search,
            "clustered sbert": self.clustered_sbert_search,
        }
        results = search_methods.get(self.embedding, lambda q: [])(query)
        self.last_results = results if results else []
        if not results:
            return "No results found."
        return self._format_results(results)
|
|
|
|
if __name__ == "__main__":
    dataset = load_dataset("ccdv/arxiv-classification", "no_ref")
    # ArxivSearch.__init__ already calls self.iface.launch(); calling it again
    # here would start the app a second time after the first server stops.
    search_engine = ArxivSearch(dataset)