# paper_classification/src/streamlit_app.py
# (repository-view metadata preserved from paste: author Raman Hasymau, commit b9bfaab, "fix")
import json
import os
import time
import arxiv
import joblib
import streamlit as st
from transformers import pipeline
# Page chrome must be configured before any other Streamlit call renders.
# NOTE: the original icon was mojibake ("πŸ“„"); restored to the intended 📄.
st.set_page_config(page_title="ArXiv Paper Classifier", page_icon="📄")

# Seed the widget-backed session keys so the Title/Abstract inputs exist
# (and survive reruns) before the first render.
for _key in ("auto_title", "auto_abstract"):
    if _key not in st.session_state:
        st.session_state[_key] = ""
@st.cache_resource
def load_pipeline():
    """Build and cache the fine-tuned text-classification pipeline.

    The model weights are expected in a ``model`` directory that sits next
    to this script. ``top_k=None`` makes the pipeline return a score for
    every label instead of only the best one.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    return pipeline(
        "text-classification",
        model=os.path.join(here, "model"),
        top_k=None,
    )
@st.cache_resource
def load_gatekeeper():
    """Load and cache the out-of-distribution "gatekeeper" detector.

    The pickled detector lives one level above this script's directory
    (presumably the project root — confirm against the repo layout).
    """
    src_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(src_dir)
    return joblib.load(os.path.join(project_root, "ood_detector.pkl"))
# Materialize both cached resources once at startup; subsequent reruns hit
# the st.cache_resource cache.
classifier = load_pipeline()
gatekeeper = load_gatekeeper()

# Static "about" panel in the sidebar.
with st.sidebar:
    st.subheader("About the Model")
    st.markdown(
        """
        - **Base model:** `distilbert-base-uncased`
        - **Fine-tuning:** Balanced ArXiv dataset (ccdv/arxiv-classification)
        - **Task:** Classification
        """
    )
    st.info("The model is cached after the first load for fast inference on subsequent requests.")
st.title("ArXiv Paper Classifier")
st.write("Enter a paper's title and abstract to predict its subject category.")


def _extract_arxiv_id(url: str) -> str:
    """Return the bare ArXiv identifier from an abs-page URL.

    Fixes the original ``split("/")[-1].split("v")[0]`` parsing, which
    destroyed old-style category-prefixed IDs (``gr-qc/0506031`` became
    ``0506031``) by dropping everything before the last slash. This keeps
    the full identifier and strips only a trailing version suffix (``v2``),
    plus any query string or fragment.
    """
    raw = url.split("arxiv.org/abs/", 1)[1]
    raw = raw.split("?", 1)[0].split("#", 1)[0].strip().strip("/")
    # Drop a trailing "vN" version marker, if present.
    base, sep, tail = raw.rpartition("v")
    return base if sep and tail.isdigit() else raw


with st.expander("Load from link", expanded=True):
    arxiv_url = st.text_input("ArXiv URL", placeholder="https://arxiv.org/abs/1706.03762")
    if st.button("Fetch paper data"):
        if "arxiv.org/abs/" not in arxiv_url:
            st.warning("Please enter a valid ArXiv URL containing 'arxiv.org/abs/'.")
        else:
            paper_id = _extract_arxiv_id(arxiv_url)
            with st.spinner("Fetching from ArXiv..."):
                try:
                    search = arxiv.Search(id_list=[paper_id])
                    # next() raises StopIteration when the ID matches nothing;
                    # handle it explicitly instead of showing an empty message.
                    paper = next(search.results())
                    st.session_state["auto_title"] = paper.title
                    st.session_state["auto_abstract"] = paper.summary
                    st.success(f"Loaded: {paper.title}")
                except StopIteration:
                    st.error(f"No paper found for ID '{paper_id}'.")
                except Exception as e:
                    st.error(f"Failed to fetch paper: {e}")

# Widget keys bind these inputs to the session-state values set above.
st.text_input("Title", key="auto_title")
st.text_area("Abstract", height=200, key="auto_abstract")

col_btn, col_bypass = st.columns([3, 1])
classify_clicked = col_btn.button("Classify", use_container_width=True)
bypass_gatekeeper = col_bypass.toggle("⚑ Bypass Gatekeeper")
if classify_clicked:
    title = st.session_state["auto_title"].strip()
    abstract = st.session_state["auto_abstract"].strip()

    if not title and not abstract:
        st.error("Please provide at least a title or an abstract.")
        st.stop()

    # Join whichever parts are present. The original always appended ". "
    # after the title, leaving a dangling separator when the abstract was
    # empty; joining only non-empty parts avoids that.
    text = ". ".join(part for part in (title, abstract) if part)

    if not bypass_gatekeeper:
        # The OOD detector's predict() returns 0 for non-scientific text
        # (per the branch below); anything else is allowed through.
        is_science = gatekeeper.predict([text])[0]
        if is_science == 0:
            st.warning(
                "This text is NOT a scientific paper. Please enter a valid scientific abstract."
            )
            st.stop()

    with st.spinner("Classifying paper"):
        # perf_counter() is monotonic, unlike time.time(), so the reported
        # duration cannot jump or go negative if the wall clock is adjusted.
        start_time = time.perf_counter()
        predictions = classifier(text)[0]
        elapsed = time.perf_counter() - start_time

    # Keep the smallest prefix of top-scoring labels whose cumulative
    # probability reaches 95%.
    predictions.sort(key=lambda x: x["score"], reverse=True)
    top_predictions = []
    cumulative = 0.0
    for pred in predictions:
        top_predictions.append(pred)
        cumulative += pred["score"]
        if cumulative >= 0.95:
            break

    st.subheader("Results")
    for pred in top_predictions:
        label = pred["label"]
        score = pred["score"]
        # Original separator was mojibake ("β€”"); restored to an em dash.
        st.write(f"**{label}** — {score * 100:.1f}%")
        st.progress(score)
    st.caption(f"Inference time: {elapsed:.3f} seconds")

    results_json = json.dumps(
        [{"label": p["label"], "score": round(p["score"], 6)} for p in top_predictions],
        indent=2,
    )
    st.download_button(
        label="Download Results JSON",
        data=results_json,
        file_name="predictions.json",
        mime="application/json",
    )