# Streamlit app: ArXiv paper classifier — predicts a subject category from a paper's title and abstract.
import json
import os
import re
import time

import arxiv
import joblib
import streamlit as st
from transformers import pipeline
# Page chrome must be configured before any other Streamlit call.
st.set_page_config(page_title="ArXiv Paper Classifier", page_icon="π")

# Pre-seed the widget-backed session keys so the "Fetch paper data" button
# can populate the Title/Abstract fields before the widgets are rendered.
for _state_key in ("auto_title", "auto_abstract"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""
@st.cache_resource
def load_pipeline():
    """Load the fine-tuned text-classification pipeline from the local ``model`` dir.

    Decorated with ``st.cache_resource`` so the model is loaded once per server
    process: Streamlit re-executes this script on every interaction, and without
    the cache the model would be reloaded on each rerun — contradicting the
    sidebar's "cached after the first load" claim.

    Returns:
        A transformers text-classification pipeline. ``top_k=None`` makes it
        return scores for *all* labels, which the cumulative-probability cutoff
        in the results section relies on.
    """
    model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model")
    return pipeline("text-classification", model=model_path, top_k=None)
@st.cache_resource
def load_gatekeeper():
    """Load the out-of-domain "gatekeeper" model from the repository root.

    Decorated with ``st.cache_resource`` so the pickle is deserialized once per
    process instead of on every Streamlit rerun.

    NOTE(review): ``joblib.load`` unpickles arbitrary code — ``ood_detector.pkl``
    must come from a trusted source.

    Returns:
        The deserialized estimator; the classify handler calls ``.predict`` on it.
    """
    # The .pkl lives one directory above this file.
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    return joblib.load(os.path.join(base_dir, "ood_detector.pkl"))
# Heavy resources are created once, up front, before any UI is drawn.
classifier = load_pipeline()
gatekeeper = load_gatekeeper()

# --- Sidebar: static model card -------------------------------------------
with st.sidebar:
    st.subheader("About the Model")
    st.markdown(
        """
- **Base model:** `distilbert-base-uncased`
- **Fine-tuning:** Balanced ArXiv dataset (ccdv/arxiv-classification)
- **Task:** Classification
"""
    )
    st.info("The model is cached after the first load for fast inference on subsequent requests.")

# --- Main page header ------------------------------------------------------
st.title("ArXiv Paper Classifier")
st.write("Enter a paper's title and abstract to predict its subject category.")
# --- Optional: auto-fill title/abstract from an ArXiv URL ------------------
with st.expander("Load from link", expanded=True):
    arxiv_url = st.text_input("ArXiv URL", placeholder="https://arxiv.org/abs/1706.03762")
    if st.button("Fetch paper data"):
        if "arxiv.org/abs/" not in arxiv_url:
            st.warning("Please enter a valid ArXiv URL containing 'arxiv.org/abs/'.")
        else:
            # Everything after ".../abs/" is the paper id. Keep the category
            # prefix of old-style ids (e.g. "hep-th/9901001") and strip only a
            # trailing version suffix such as "v2". The previous
            # ``split("v")[0]`` corrupted any id containing the letter "v" and
            # dropped old-style category prefixes.
            paper_id = arxiv_url.split("arxiv.org/abs/", 1)[1].strip("/")
            paper_id = re.sub(r"v\d+$", "", paper_id)
            with st.spinner("Fetching from ArXiv..."):
                try:
                    search = arxiv.Search(id_list=[paper_id])
                    # Client().results() replaces the deprecated Search.results().
                    paper = next(arxiv.Client().results(search))
                    st.session_state["auto_title"] = paper.title
                    st.session_state["auto_abstract"] = paper.summary
                    st.success(f"Loaded: {paper.title}")
                except StopIteration:
                    # next() on an empty result iterator: no such paper id.
                    st.error(f"No paper found for id '{paper_id}'.")
                except Exception as e:
                    st.error(f"Failed to fetch paper: {e}")
# The widgets share keys with the session-state entries the fetcher writes,
# so a successful fetch pre-fills both fields.
st.text_input("Title", key="auto_title")
st.text_area("Abstract", height=200, key="auto_abstract")

# Classify button next to a small gatekeeper-bypass toggle.
col_btn, col_bypass = st.columns([3, 1])
classify_clicked = col_btn.button("Classify", use_container_width=True)
bypass_gatekeeper = col_bypass.toggle("β‘ Bypass Gatekeeper")
# --- Classification flow ---------------------------------------------------
if classify_clicked:
    title = st.session_state["auto_title"]
    abstract = st.session_state["auto_abstract"]

    # Require at least one non-blank field.
    if not title.strip() and not abstract.strip():
        st.error("Please provide at least a title or an abstract.")
        st.stop()

    # Model input: "<title>. <abstract>" when a title exists, else the abstract.
    text = f"{title.strip()}. {abstract.strip()}" if title.strip() else abstract.strip()

    if not bypass_gatekeeper:
        # Gatekeeper verdict: 0 means out-of-domain (not a scientific paper).
        is_science = gatekeeper.predict([text])[0]
        if is_science == 0:
            st.warning(
                "This text is NOT a scientific paper. Please enter a valid scientific abstract."
            )
            st.stop()

    with st.spinner("Classifying paper"):
        # perf_counter is the right clock for measuring elapsed intervals.
        start_time = time.perf_counter()
        # truncation=True guards against abstracts longer than the model's
        # token limit, which would otherwise raise at inference time.
        predictions = classifier(text, truncation=True)[0]
        elapsed = time.perf_counter() - start_time

    # Keep the smallest prefix of labels whose scores sum to >= 95%.
    predictions.sort(key=lambda x: x["score"], reverse=True)
    top_predictions = []
    cumulative = 0.0
    for pred in predictions:
        top_predictions.append(pred)
        cumulative += pred["score"]
        if cumulative >= 0.95:
            break

    st.subheader("Results")
    for pred in top_predictions:
        label = pred["label"]
        score = pred["score"]
        st.write(f"**{label}** β {score * 100:.1f}%")
        st.progress(score)
    st.caption(f"Inference time: {elapsed:.3f} seconds")

    # Offer the shown predictions as a downloadable JSON file.
    results_json = json.dumps(
        [{"label": p["label"], "score": round(p["score"], 6)} for p in top_predictions],
        indent=2,
    )
    st.download_button(
        label="Download Results JSON",
        data=results_json,
        file_name="predictions.json",
        mime="application/json",
    )