"""Streamlit app: predict arXiv categories from a paper's title and abstract.

Loads a fine-tuned T5 sequence-classification model from ./best_model plus a
label-id mapping from ./label_encoder.json, and serves a single-page UI.
"""

import html
import json
import re
import unicodedata
from pathlib import Path

import numpy as np
import streamlit as st
import torch
from transformers import AutoTokenizer, T5ForSequenceClassification

# --- Configuration -----------------------------------------------------------
BASE_DIR = Path(__file__).parent
MODEL_DIR = BASE_DIR / "best_model"
LABEL_ENCODER = BASE_DIR / "label_encoder.json"
MAX_LENGTH = 256            # tokenizer truncation length (tokens)
TOP_PROB = 0.95             # cumulative-probability cutoff for returned labels
DEVICE = torch.device("cpu")
MIN_CHARS = 20              # minimum length for title and abstract (chars)
MAX_CHARS = 5000            # maximum length for title and abstract (chars)


def clean_text(text: str) -> str:
    """Normalize user-pasted text for the model.

    Unescapes HTML entities, applies NFKC normalization, removes inline
    $...$ math and simple one-argument LaTeX commands, and collapses all
    whitespace runs to single spaces.
    """
    text = html.unescape(text)
    text = unicodedata.normalize("NFKC", text)
    text = re.sub(r"\$.*?\$", "", text)              # inline math: $...$
    text = re.sub(r"\\[a-zA-Z]+\{.*?\}", "", text)   # e.g. \emph{...}, \cite{...}
    text = re.sub(r"\s+", " ", text)
    return text.strip()


def validate(title: str, abstract: str):
    """Return a human-readable error message, or None if inputs are valid.

    Both fields must be at least MIN_CHARS (after stripping whitespace)
    and at most MAX_CHARS long.
    """
    if len(title.strip()) < MIN_CHARS:
        return f"Title too short (at least {MIN_CHARS} characters)"
    if len(abstract.strip()) < MIN_CHARS:
        return f"Abstract too short (at least {MIN_CHARS} characters)"
    if len(title) > MAX_CHARS:
        return f"Title too long (max {MAX_CHARS} characters)"
    if len(abstract) > MAX_CHARS:
        return f"Abstract too long (max {MAX_CHARS} characters)"
    return None


@st.cache_resource(show_spinner="Loading model…")
def load_model():
    """Load tokenizer, model, and id->label mapping once per session.

    Returns (tokenizer, model, id2label). On any failure the error is shown
    in the UI and the script run is stopped via st.stop().
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
        model = T5ForSequenceClassification.from_pretrained(MODEL_DIR)
        model.to(DEVICE)
        model.eval()
        with open(LABEL_ENCODER) as f:
            id2label = {int(k): v for k, v in json.load(f).items()}
        return tokenizer, model, id2label
    except FileNotFoundError as e:
        st.error(f"Model files not found: {e}")
        st.stop()
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        st.stop()


@st.cache_data(show_spinner=False)
def predict(title: str, abstract: str):
    """Classify a paper and return its top-probability label set.

    Builds a "classify: ..." prompt from the cleaned title and/or abstract
    (abstract clipped to 1000 chars before tokenization), runs the model on
    CPU, and returns the smallest list of labels whose probabilities sum to
    at least TOP_PROB, as dicts with "label" and "probability" keys, sorted
    by descending probability.
    """
    tokenizer, model, id2label = load_model()

    if title and abstract:
        text = f"classify: {clean_text(title)} [SEP] {clean_text(abstract)[:1000]}"
    elif title:
        text = f"classify: {clean_text(title)}"
    else:
        text = f"classify: {clean_text(abstract)[:1000]}"

    eos_id = tokenizer.eos_token_id
    enc = tokenizer(text, max_length=MAX_LENGTH, truncation=True,
                    return_tensors="pt")
    ids = enc["input_ids"][0].tolist()
    # The "[SEP]" marker can introduce an extra EOS token mid-sequence;
    # keep only a single trailing EOS so the model sees one end marker.
    if eos_id is not None and ids.count(eos_id) > 1:
        ids = [t for t in ids[:-1] if t != eos_id] + [eos_id]

    input_ids = torch.tensor([ids], dtype=torch.long).to(DEVICE)
    # No padding here, so the attention mask is all ones.
    attention_mask = torch.ones_like(input_ids)

    try:
        with torch.no_grad():
            logits = model(input_ids=input_ids,
                           attention_mask=attention_mask).logits
    except Exception as e:
        st.error(f"Model inference failed: {e}")
        st.stop()

    probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
    sorted_idx = np.argsort(probs)[::-1]
    cumsum = np.cumsum(probs[sorted_idx])
    # searchsorted finds the first index where the cumulative mass reaches
    # TOP_PROB; +1 makes the slice inclusive of that label.
    return [
        {"label": id2label.get(int(i), f"class_{i}"), "probability": float(probs[i])}
        for i in sorted_idx[:int(np.searchsorted(cumsum, TOP_PROB)) + 1]
    ]


# --- UI ----------------------------------------------------------------------
st.set_page_config(page_title="Paper Classifier", page_icon="📄", layout="centered")
st.title("📄 Research Paper Classifier")
st.caption("Predicts arXiv categories from a paper's title and abstract.")

load_model()  # warm the cache up front so the first click is fast


def _clear_inputs() -> None:
    """Reset both input widgets via their session-state keys."""
    st.session_state["title_input"] = ""
    st.session_state["abstract_input"] = ""


title = st.text_input("Title *", placeholder="e.g. Attention Is All You Need",
                      max_chars=MAX_CHARS, key="title_input")
abstract = st.text_area("Abstract *", placeholder="Paste the abstract here…",
                        height=200, max_chars=MAX_CHARS, key="abstract_input")

col1, col2 = st.columns([1, 4])
classify_btn = col1.button("Classify", type="primary", use_container_width=True)
# BUGFIX: the Clear button previously had a no-op callback (lambda: None);
# it now actually resets both inputs.
col2.button("Clear", on_click=_clear_inputs, use_container_width=True)

if classify_btn:
    error = validate(title, abstract)
    # BUGFIX: these branches were independent `if` statements, so a failed
    # validation showed a warning but still ran the classifier. They are
    # now a single if/elif/else chain.
    if error:
        st.warning(error)
    elif not clean_text(title) and not clean_text(abstract):
        st.warning("Please enter at least a title or an abstract.")
    else:
        with st.spinner("Classifying…"):
            results = predict(title, abstract)
        st.divider()
        st.subheader("Predicted categories")
        st.caption(
            f"Top-95% probability set ({len(results)} labels) "
            f"combined probability {sum(r['probability'] for r in results):.1%}"
        )
        for r in results:
            col1, col2 = st.columns([3, 1])
            col1.markdown(f"**{r['label']}**")
            col2.markdown(f"`{r['probability']:.1%}`")
            st.progress(float(r["probability"]))