# paper_labeling/src/streamlit_app.py
# Streamlit app: predicts arXiv categories from a paper's title and abstract.
# (Hosting-UI header residue converted to comments — author: sastsy, commit 68b3862.)
import json
import re
from pathlib import Path
import html
import unicodedata
import numpy as np
import streamlit as st
import torch
from transformers import AutoTokenizer, T5ForSequenceClassification
BASE_DIR = Path(__file__).parent
MODEL_DIR = BASE_DIR / "best_model"
LABEL_ENCODER = BASE_DIR / "label_encoder.json"
MAX_LENGTH = 256
TOP_PROB = 0.95
DEVICE = torch.device("cpu")
MIN_CHARS = 20
MAX_CHARS = 5000
def clean_text(text: str) -> str:
    """Normalize raw title/abstract text before tokenization.

    Unescapes HTML entities, applies NFKC Unicode normalization, strips
    inline TeX math (``$...$``) and simple LaTeX commands (``\\cmd{...}``),
    then collapses all whitespace runs to single spaces.
    """
    unescaped = unicodedata.normalize("NFKC", html.unescape(text))
    without_math = re.sub(r"\$.*?\$", "", unescaped)
    without_commands = re.sub(r"\\[a-zA-Z]+\{.*?\}", "", without_math)
    return re.sub(r"\s+", " ", without_commands).strip()
def validate(title: str, abstract: str):
    """Check field lengths; return an error message string or None if OK.

    Checks run in a fixed order (title short, abstract short, title long,
    abstract long) and the first failure wins.
    """
    checks = (
        (len(title.strip()) < MIN_CHARS,
         f"Title too short (at least {MIN_CHARS} characters)"),
        (len(abstract.strip()) < MIN_CHARS,
         f"Abstract too short (at least {MIN_CHARS} characters)"),
        (len(title) > MAX_CHARS,
         f"Title too long (max {MAX_CHARS} characters)"),
        (len(abstract) > MAX_CHARS,
         f"Abstract too long (max {MAX_CHARS} characters)"),
    )
    for failed, message in checks:
        if failed:
            return message
    return None
@st.cache_resource(show_spinner="Loading model…")
def load_model():
    """Load tokenizer, classifier and label map once per process (cached).

    Returns:
        (tokenizer, model, id2label) where ``id2label`` maps integer class
        index -> label name from ``label_encoder.json``.

    Any load failure is surfaced in the UI via ``st.error`` and the script
    run is halted with ``st.stop()``.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
        model = T5ForSequenceClassification.from_pretrained(MODEL_DIR)
        model.to(DEVICE)
        model.eval()
        # FIX: read the JSON explicitly as UTF-8 — the platform default
        # encoding is not guaranteed, and labels may contain non-ASCII.
        # JSON object keys are strings, but the model emits integer class
        # ids, so re-key the mapping with int().
        with open(LABEL_ENCODER, encoding="utf-8") as f:
            id2label = {int(k): v for k, v in json.load(f).items()}
        return tokenizer, model, id2label
    except FileNotFoundError as e:
        st.error(f"Model files not found: {e}")
        st.stop()
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        st.stop()
@st.cache_data(show_spinner=False)
def predict(title: str, abstract: str):
    """Classify a paper and return its top-probability label set.

    Builds a ``classify: …`` prompt from whichever of title/abstract is
    non-empty (abstract clipped to 1000 cleaned characters), runs the
    cached model on CPU, and returns the smallest prefix of labels (by
    descending probability) whose cumulative probability reaches TOP_PROB.

    Returns:
        list of {"label": str, "probability": float} dicts, most likely first.
    """
    tokenizer, model, id2label = load_model()

    cleaned_title = clean_text(title)
    cleaned_abstract = clean_text(abstract)[:1000]
    if title and abstract:
        text = f"classify: {cleaned_title} [SEP] {cleaned_abstract}"
    elif title:
        text = f"classify: {cleaned_title}"
    else:
        text = f"classify: {cleaned_abstract}"

    encoded = tokenizer(text, max_length=MAX_LENGTH, truncation=True, return_tensors="pt")
    token_ids = encoded["input_ids"][0].tolist()
    # If more than one EOS token slipped in, drop every EOS from the body
    # and force a single EOS at the end of the sequence.
    eos_id = tokenizer.eos_token_id
    if eos_id is not None and token_ids.count(eos_id) > 1:
        token_ids = [t for t in token_ids[:-1] if t != eos_id] + [eos_id]

    input_ids = torch.tensor([token_ids], dtype=torch.long).to(DEVICE)
    attention_mask = torch.ones_like(input_ids)  # single sequence, no padding
    try:
        with torch.no_grad():
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    except Exception as e:
        st.error(f"Model inference failed: {e}")
        st.stop()

    probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
    order = np.argsort(probs)[::-1]
    cumulative = np.cumsum(probs[order])
    # Smallest k such that the top-k probabilities sum to >= TOP_PROB.
    keep = int(np.searchsorted(cumulative, TOP_PROB)) + 1
    return [
        {"label": id2label.get(int(idx), f"class_{idx}"), "probability": float(probs[idx])}
        for idx in order[:keep]
    ]
st.set_page_config(page_title="Paper Classifier", page_icon="📄", layout="centered")
st.title("📄 Research Paper Classifier")
st.caption("Predicts arXiv categories from a paper's title and abstract.")
# Warm the cached model at startup so the first classification is snappy.
load_model()


def _clear_inputs() -> None:
    """Reset both text widgets; runs as the Clear button's on_click callback."""
    st.session_state["paper_title"] = ""
    st.session_state["paper_abstract"] = ""


# Widgets need explicit keys so the Clear callback can reset them via
# st.session_state.
title = st.text_input(
    "Title *",
    placeholder="e.g. Attention Is All You Need",
    max_chars=MAX_CHARS,
    key="paper_title",
)
abstract = st.text_area(
    "Abstract *",
    placeholder="Paste the abstract here…",
    height=200,
    max_chars=MAX_CHARS,
    key="paper_abstract",
)
col1, col2 = st.columns([1, 4])
classify_btn = col1.button("Classify", type="primary", use_container_width=True)
# FIX: the Clear button previously used `on_click=lambda: None`, a no-op
# that never cleared the inputs.
col2.button("Clear", on_click=_clear_inputs, use_container_width=True)
if classify_btn:
    # FIX: the original showed a validation warning but then fell through
    # to the empty-text check and could still run classification on invalid
    # input. Use an elif chain so only one branch executes.
    error = validate(title, abstract)
    if error:
        st.warning(error)
    elif not clean_text(title) and not clean_text(abstract):
        # Both fields clean down to nothing (e.g. pure LaTeX/markup).
        st.warning("Please enter at least a title or an abstract.")
    else:
        with st.spinner("Classifying…"):
            results = predict(title, abstract)
        st.divider()
        st.subheader("Predicted categories")
        st.caption(
            f"Top-95% probability set ({len(results)} labels) "
            f"combined probability {sum(r['probability'] for r in results):.1%}"
        )
        for r in results:
            col1, col2 = st.columns([3, 1])
            col1.markdown(f"**{r['label']}**")
            col2.markdown(f"`{r['probability']:.1%}`")
            st.progress(float(r["probability"]))