Spaces:
Running
Running
| import json | |
| import re | |
| from pathlib import Path | |
| import html | |
| import unicodedata | |
| import numpy as np | |
| import streamlit as st | |
| import torch | |
| from transformers import AutoTokenizer, T5ForSequenceClassification | |
# --- Paths and inference configuration ---
BASE_DIR = Path(__file__).parent  # directory containing this script
MODEL_DIR = BASE_DIR / "best_model"  # fine-tuned T5 checkpoint + tokenizer files
LABEL_ENCODER = BASE_DIR / "label_encoder.json"  # JSON map: class index -> label name
MAX_LENGTH = 256  # tokenizer truncation length (tokens)
TOP_PROB = 0.95  # cumulative-probability cutoff for the reported label set
DEVICE = torch.device("cpu")  # inference device; CPU-only deployment
MIN_CHARS = 20  # minimum accepted length for title/abstract (after strip)
MAX_CHARS = 5000  # maximum accepted length for title/abstract
def clean_text(text: str) -> str:
    """Normalize raw paper text for the model.

    Unescapes HTML entities, applies NFKC Unicode normalization, strips
    inline ``$...$`` math and LaTeX commands of the form ``\\cmd{...}``,
    then collapses all whitespace runs to single spaces and trims the ends.
    """
    normalized = unicodedata.normalize("NFKC", html.unescape(text))
    # Drop TeX-ish fragments that carry little signal for classification.
    for pattern in (r"\$.*?\$", r"\\[a-zA-Z]+\{.*?\}"):
        normalized = re.sub(pattern, "", normalized)
    return re.sub(r"\s+", " ", normalized).strip()
def validate(title: str, abstract: str):
    """Check both fields against the length limits.

    Returns a human-readable error message for the first failing check
    (short checks before long checks, title before abstract), or None
    when both fields are acceptable.
    """
    fields = (("Title", title), ("Abstract", abstract))
    for name, value in fields:
        if len(value.strip()) < MIN_CHARS:
            return f"{name} too short (at least {MIN_CHARS} characters)"
    for name, value in fields:
        if len(value) > MAX_CHARS:
            return f"{name} too long (max {MAX_CHARS} characters)"
    return None
@st.cache_resource
def load_model():
    """Load the tokenizer, classification model, and id->label mapping.

    Decorated with ``st.cache_resource`` so the heavy checkpoint is loaded
    once per server process instead of on every Streamlit rerun — previously
    this function was re-executed both at startup and inside each
    ``predict()`` call, reloading the model from disk each time.

    Returns:
        (tokenizer, model, id2label) where id2label maps int class index
        to its label string.

    On failure, reports the error in the UI and halts the script run.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
        model = T5ForSequenceClassification.from_pretrained(MODEL_DIR)
        model.to(DEVICE)
        model.eval()
        with open(LABEL_ENCODER) as f:
            # JSON keys are strings; convert back to int class indices.
            id2label = {int(k): v for k, v in json.load(f).items()}
        return tokenizer, model, id2label
    except FileNotFoundError as e:
        st.error(f"Model files not found: {e}")
        st.stop()
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        st.stop()
def predict(title: str, abstract: str):
    """Classify a paper and return the smallest label set covering TOP_PROB.

    Returns a list of ``{"label": str, "probability": float}`` dicts sorted
    by descending probability, truncated to the shortest prefix whose
    cumulative probability reaches TOP_PROB.
    """
    tokenizer, model, id2label = load_model()
    # Build the T5 prompt; the abstract is capped at 1000 chars before
    # tokenization. Either field may be empty (the caller guards against both).
    if title and abstract:
        text = f"classify: {clean_text(title)} [SEP] {clean_text(abstract)[:1000]}"
    elif title:
        text = f"classify: {clean_text(title)}"
    else:
        text = f"classify: {clean_text(abstract)[:1000]}"
    eos_id = tokenizer.eos_token_id
    enc = tokenizer(text, max_length=MAX_LENGTH, truncation=True, return_tensors="pt")
    ids = enc["input_ids"][0].tolist()
    # If the encoded sequence contains more than one EOS token, strip the
    # extras and keep exactly one trailing EOS.
    if eos_id is not None and ids.count(eos_id) > 1:
        ids = [t for t in ids[:-1] if t != eos_id] + [eos_id]
    input_ids = torch.tensor([ids], dtype=torch.long).to(DEVICE)
    # Single unpadded sequence (truncation only), so every position is a real
    # token: attend to all of them.
    attention_mask = torch.ones_like(input_ids)
    try:
        with torch.no_grad():
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    except Exception as e:
        st.error(f"Model inference failed: {e}")
        st.stop()
    probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
    sorted_idx = np.argsort(probs)[::-1]  # class indices, highest prob first
    cumsum = np.cumsum(probs[sorted_idx])
    # searchsorted gives the first position where cumulative probability
    # reaches TOP_PROB; +1 converts that index into a count of labels to keep.
    return [
        {"label": id2label.get(int(i), f"class_{i}"), "probability": float(probs[i])}
        for i in sorted_idx[:int(np.searchsorted(cumsum, TOP_PROB)) + 1]
    ]
# --- Page layout and input widgets ---
st.set_page_config(page_title="Paper Classifier", page_icon="📄", layout="centered")
st.title("📄 Research Paper Classifier")
st.caption("Predicts arXiv categories from a paper's title and abstract.")

# Warm up the model at startup so load errors surface immediately and the
# first classification click is fast.
load_model()


def _clear_inputs() -> None:
    """Reset both input widgets; used as the Clear button's on_click callback."""
    st.session_state["title_input"] = ""
    st.session_state["abstract_input"] = ""


title = st.text_input(
    "Title *",
    placeholder="e.g. Attention Is All You Need",
    max_chars=MAX_CHARS,
    key="title_input",
)
abstract = st.text_area(
    "Abstract *",
    placeholder="Paste the abstract here…",
    height=200,
    max_chars=MAX_CHARS,
    key="abstract_input",
)
col1, col2 = st.columns([1, 4])
classify_btn = col1.button("Classify", type="primary", use_container_width=True)
# Fix: the Clear button previously had a no-op callback (lambda: None) and
# never cleared anything; it now resets both fields via session state.
col2.button("Clear", on_click=_clear_inputs, use_container_width=True)
if classify_btn:
    # Fix: validation must gate classification. Previously the length-check
    # warning was shown but, because the emptiness check was a separate `if`,
    # predict() still ran on invalid input. Chaining with `elif` makes each
    # warning exclusive and stops the flow.
    error = validate(title, abstract)
    if error:
        st.warning(error)
    elif not clean_text(title) and not clean_text(abstract):
        st.warning("Please enter at least a title or an abstract.")
    else:
        with st.spinner("Classifying…"):
            results = predict(title, abstract)
        st.divider()
        st.subheader("Predicted categories")
        st.caption(
            f"Top-95% probability set ({len(results)} labels) "
            f"combined probability {sum(r['probability'] for r in results):.1%}"
        )
        # One row per predicted label: name, probability, and a progress bar.
        for r in results:
            col1, col2 = st.columns([3, 1])
            col1.markdown(f"**{r['label']}**")
            col2.markdown(f"`{r['probability']:.1%}`")
            st.progress(float(r["probability"]))