Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| from pathlib import Path | |
| import streamlit as st | |
| import torch | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| from paper_classifier import ( | |
| BASE_MODEL_NAME, | |
| DEFAULT_MODEL_DIR, | |
| EXAMPLES, | |
| EXPECTED_ARXIV_CATEGORIES, | |
| MAX_LENGTH, | |
| TOP_P_THRESHOLD, | |
| format_input_text, | |
| take_top_p, | |
| ) | |
# Directory of the fine-tuned checkpoint; the ARXIV_MODEL_DIR environment variable, when set,
# takes precedence over DEFAULT_MODEL_DIR from paper_classifier.
MODEL_DIR = Path(os.getenv("ARXIV_MODEL_DIR", str(DEFAULT_MODEL_DIR)))
@st.cache_resource(show_spinner=False)
def load_model_bundle() -> tuple[AutoTokenizer, AutoModelForSequenceClassification]:
    """Load the fine-tuned tokenizer and sequence classifier from MODEL_DIR.

    Streamlit re-executes the whole script on every interaction, so without caching the
    checkpoint would be read from disk on every form submission. ``st.cache_resource`` keeps
    a single (tokenizer, model) pair per server process and shares it between sessions, so
    callers must treat the returned objects as read-only. The spinner is disabled because the
    caller already shows its own.

    Note: MODEL_DIR is a module global, not an argument, so it is not part of the cache key.

    Returns:
        ``(tokenizer, model)`` with the model already switched to eval mode.

    Raises:
        FileNotFoundError: if ``MODEL_DIR/config.json`` does not exist (no trained checkpoint).
    """
    config_path = MODEL_DIR / "config.json"
    if not config_path.exists():
        raise FileNotFoundError(
            f"Не найден fine-tuned checkpoint в {MODEL_DIR}. Сначала обучите модель через train_distilbert.py."
        )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR.as_posix())
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR.as_posix())
    # Inference only: disable dropout and other training-time behaviour.
    model.eval()
    return tokenizer, model
def predict_topics(title: str, abstract: str) -> list[dict[str, str | float]]:
    """Classify an article and return its most probable topics.

    Args:
        title: Article title; may be empty.
        abstract: Article abstract; may be empty.

    Returns:
        Records of the form ``{"label": str, "score": float}`` sorted by descending
        probability, then passed through ``take_top_p(records, TOP_P_THRESHOLD)``
        (presumably keeping the smallest prefix whose probabilities reach the threshold —
        confirm in paper_classifier).

    Raises:
        ValueError: If ``format_input_text`` returns an empty string. Presumably this happens
            when both title and abstract are blank. The message is meant to be shown to the user.
        FileNotFoundError: Propagated from ``load_model_bundle`` when the checkpoint is missing.
    """
    article_text = format_input_text(title, abstract)
    if not article_text:
        raise ValueError("Введите хотя бы название статьи или abstract.")

    tokenizer, model = load_model_bundle()
    inputs = tokenizer(
        article_text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
    )
    # Put the inputs on whatever device holds the model weights (CPU or GPU).
    device = next(model.parameters()).device
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.inference_mode():
        # A single text is encoded, so take the first (only) row of the batch.
        logits = model(**inputs).logits[0]
    # Upcast before softmax so a half-precision checkpoint still yields float32 probabilities.
    probabilities = torch.softmax(logits.float(), dim=-1).cpu().tolist()

    # config.id2label maps class index -> label name; indices without an entry get "Label <i>".
    id2label = getattr(model.config, "id2label", None) or {}
    records = [
        {"label": str(id2label.get(index, f"Label {index}")), "score": float(score)}
        for index, score in enumerate(probabilities)
    ]
    records.sort(key=lambda record: record["score"], reverse=True)
    return take_top_p(records, TOP_P_THRESHOLD)
def apply_styles() -> None:
    """Inject the app-wide dark theme stylesheet into the page.

    The CSS imports the Manrope and IBM Plex Mono web fonts and declares shared colour
    variables under ``:root``. It restyles Streamlit's built-in elements (app background,
    sidebar, text inputs, select boxes, buttons, expanders, progress bars). It also defines
    the custom classes (``hero``, ``info-*``, ``result-*``, ``caption-note``) used by the
    raw-HTML snippets this app renders with ``st.markdown``.
    """
    # unsafe_allow_html=True is required, otherwise Streamlit escapes the <style> tag as text.
    st.markdown(
        """
<style>
@import url('https://fonts.googleapis.com/css2?family=Manrope:wght@400;600;700;800&family=IBM+Plex+Mono:wght@400;500&display=swap');
:root {
--paper: rgba(22, 27, 34, 0.92);
--card: rgba(30, 36, 46, 0.88);
--ink: #e6edf3;
--muted: #8b9cb3;
--accent: #2dd4bf;
--accent-dim: rgba(45, 212, 191, 0.14);
--accent-2: #fb923c;
--border: rgba(230, 237, 243, 0.09);
--shadow: 0 24px 80px rgba(0, 0, 0, 0.45);
--surface-0: #0d1117;
--surface-1: #161b22;
--surface-input: #21262d;
}
.stApp {
background:
radial-gradient(circle at 12% 8%, rgba(45, 212, 191, 0.09), transparent 32%),
radial-gradient(circle at 88% 4%, rgba(251, 146, 60, 0.07), transparent 28%),
linear-gradient(180deg, #0d1117 0%, #0a0e14 100%);
color: var(--ink);
font-family: "Manrope", sans-serif;
}
[data-testid="stAppViewContainer"],
[data-testid="stHeader"] {
background: transparent;
}
[data-testid="stSidebar"] {
background: linear-gradient(180deg, var(--surface-1) 0%, #121820 100%);
border-right: 1px solid var(--border);
}
[data-testid="stSidebar"] .stMarkdown,
[data-testid="stSidebar"] label,
[data-testid="stSidebar"] span {
color: var(--ink) !important;
}
.block-container {
padding-top: 2.2rem;
padding-bottom: 2.2rem;
max-width: 1100px;
}
section.main [data-testid="stMarkdownContainer"] p,
section.main [data-testid="stMarkdownContainer"] li,
section.main label,
.stSubheader {
color: var(--ink) !important;
}
.stTextInput label,
.stTextArea label {
color: var(--muted) !important;
}
.stTextInput input,
.stTextArea textarea {
background-color: var(--surface-input) !important;
color: var(--ink) !important;
border: 1px solid var(--border) !important;
border-radius: 12px !important;
}
.stTextInput input:focus,
.stTextArea textarea:focus {
border-color: rgba(45, 212, 191, 0.45) !important;
box-shadow: 0 0 0 1px rgba(45, 212, 191, 0.25);
}
div[data-baseweb="select"] > div {
background-color: var(--surface-input) !important;
border-color: var(--border) !important;
color: var(--ink) !important;
}
.stButton > button {
background: linear-gradient(135deg, #0d9488 0%, #0f766e 100%) !important;
color: #f0fdfa !important;
border: none !important;
font-weight: 700 !important;
border-radius: 12px !important;
}
.stButton > button:hover {
background: linear-gradient(135deg, #14b8a6 0%, #0d9488 100%) !important;
color: #fff !important;
}
[data-testid="stExpander"] {
background: var(--paper);
border: 1px solid var(--border);
border-radius: 14px;
}
[data-testid="stExpander"] summary {
color: var(--ink) !important;
}
.stProgress > div > div {
background-color: rgba(45, 212, 191, 0.35) !important;
}
.stProgress > div > div > div {
background: linear-gradient(90deg, #2dd4bf, #14b8a6) !important;
}
.hero {
padding: 2rem 2.2rem;
border-radius: 28px;
background: linear-gradient(145deg, rgba(30, 36, 46, 0.95), rgba(22, 27, 34, 0.88));
border: 1px solid var(--border);
box-shadow: var(--shadow);
backdrop-filter: blur(12px);
margin-bottom: 1.2rem;
}
.hero-kicker {
font-size: 0.82rem;
text-transform: uppercase;
letter-spacing: 0.18em;
color: var(--accent);
font-weight: 800;
margin-bottom: 0.65rem;
}
.hero h1 {
font-size: clamp(2rem, 3.5vw, 3.7rem);
line-height: 0.98;
margin: 0;
max-width: 11ch;
color: var(--ink);
}
.hero p {
max-width: 56rem;
color: var(--muted);
font-size: 1.02rem;
line-height: 1.65;
margin-top: 0.95rem;
margin-bottom: 0;
}
.info-strip {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(190px, 1fr));
gap: 0.8rem;
margin: 1rem 0 1.25rem;
}
.info-card {
padding: 1rem 1.05rem;
border-radius: 20px;
background: var(--paper);
border: 1px solid var(--border);
}
.info-label {
color: var(--muted);
font-size: 0.84rem;
margin-bottom: 0.3rem;
}
.info-value {
font-weight: 700;
color: var(--ink);
word-break: break-word;
}
.result-card {
padding: 1rem 1.1rem 1.1rem;
border-radius: 22px;
background: var(--card);
border: 1px solid var(--border);
margin-bottom: 0.9rem;
}
.result-rank {
display: inline-block;
padding: 0.2rem 0.55rem;
margin-bottom: 0.65rem;
border-radius: 999px;
background: var(--accent-dim);
color: var(--accent);
font-size: 0.8rem;
font-weight: 800;
letter-spacing: 0.06em;
text-transform: uppercase;
}
.result-title {
font-size: 1.12rem;
font-weight: 800;
margin-bottom: 0.35rem;
color: var(--ink);
}
.result-score {
color: var(--accent-2);
font-family: "IBM Plex Mono", monospace;
font-size: 0.92rem;
margin-bottom: 0.75rem;
}
.caption-note {
color: var(--muted);
font-size: 0.92rem;
}
[data-testid="stSidebar"] pre,
[data-testid="stSidebar"] code {
background-color: var(--surface-input) !important;
color: #a5f3fc !important;
border: 1px solid var(--border) !important;
border-radius: 10px !important;
}
[data-testid="stSidebar"] [data-testid="stMarkdownContainer"] a {
color: var(--accent) !important;
}
</style>
""",
        unsafe_allow_html=True,
    )
def render_hero() -> None:
    """Render the page header and an info strip describing the active checkpoint.

    Both snippets are raw HTML (``unsafe_allow_html=True``). The dynamic values are
    HTML-escaped before interpolation. MODEL_DIR can be overridden through the
    ARXIV_MODEL_DIR environment variable, so it may contain characters such as ``&``
    or ``<`` that would otherwise corrupt (or inject into) the markup.
    """
    import html

    st.markdown(
        """
<section class="hero">
<div class="hero-kicker">Моисейин Андрей Денисович</div>
<h1>Классификатор научных статей</h1>
<p>
Вот не зря я учил веб разработку 4 года, чтобы писать на html, css и js. Эх, был бы реакт.
</p>
</section>
""",
        unsafe_allow_html=True,
    )

    base_model = html.escape(str(BASE_MODEL_NAME))
    checkpoint = html.escape(str(MODEL_DIR))
    st.markdown(
        f"""
<div class="info-strip">
<div class="info-card">
<div class="info-label">Базовая модель</div>
<div class="info-value">{base_model}</div>
</div>
<div class="info-card">
<div class="info-label">Checkpoint</div>
<div class="info-value">{checkpoint}</div>
</div>
<div class="info-card">
<div class="info-label">Макс. длина</div>
<div class="info-value">{MAX_LENGTH} токенов</div>
</div>
</div>
""",
        unsafe_allow_html=True,
    )
def render_results(records: list[dict[str, str | float]]) -> None:
    """Render predicted topics as ranked HTML cards, each followed by a probability bar.

    Args:
        records: Dicts with a ``"label"`` (str) and a ``"score"`` (probability) key, such as
            the list returned by ``predict_topics``. They are rendered in the given order.
    """
    import html

    st.subheader("Ответ")
    # Derive the percentage from the threshold actually used for filtering instead of a
    # hard-coded "95%" that silently goes stale if TOP_P_THRESHOLD changes.
    # Assumes TOP_P_THRESHOLD is a fraction in [0, 1] — confirm in paper_classifier.
    st.caption(
        "Классы отсортированы по убыванию вероятности. "
        f"Показаны только те, которые набрали {TOP_P_THRESHOLD:.0%}."
    )
    for rank, record in enumerate(records, start=1):
        score = float(record["score"])
        # Labels come from the checkpoint config; escape them because the card is raw HTML.
        label = html.escape(str(record["label"]))
        st.markdown(
            f"""
<div class="result-card">
<div class="result-rank">#{rank}</div>
<div class="result-title">{label}</div>
<div class="result-score">p = {score:.2%}</div>
</div>
""",
            unsafe_allow_html=True,
        )
        # st.progress rejects floats outside [0.0, 1.0]; clamp defensively.
        st.progress(min(max(score, 0.0), 1.0))

    total = sum(float(record["score"]) for record in records)
    st.caption(f"Суммарная вероятность показанных тем: {total:.2%}")
def render_sidebar() -> None:
    """Render the sidebar preset picker and copy the chosen example into the form state.

    On first run it seeds the session-state keys backing the form widgets
    (``article_title``, ``article_abstract``) plus ``selected_preset``. When the picked
    preset differs from the remembered one, it overwrites the title/abstract keys.
    Picking "Свой текст" clears them; any other option loads that entry of EXAMPLES.
    It must run before the keyed form widgets are created in the same script run.
    """
    custom_option = "Свой текст"
    state = st.session_state

    for state_key, initial in (
        ("selected_preset", custom_option),
        ("article_title", ""),
        ("article_abstract", ""),
    ):
        if state_key not in state:
            state[state_key] = initial

    st.sidebar.markdown("### Быстрый старт")
    chosen = st.sidebar.selectbox(
        "Пример статьи",
        options=[custom_option, *EXAMPLES.keys()],
    )
    if chosen == state.selected_preset:
        return

    if chosen == custom_option:
        new_title, new_abstract = "", ""
    else:
        example = EXAMPLES[chosen]
        new_title, new_abstract = example["title"], example["abstract"]
    state.article_title = new_title
    state.article_abstract = new_abstract
    state.selected_preset = chosen
def main() -> None:
    """Build the page and classify the entered article when the form is submitted.

    Renders the theme, header and sidebar presets, then the input form. On submit it runs
    ``predict_topics`` and renders the ranked topics. Empty input (``ValueError``) and a
    missing checkpoint (``FileNotFoundError``) are shown as Streamlit messages instead of a
    raw traceback.
    """
    st.set_page_config(
        page_title="Article Topic Classifier",
        layout="wide",
    )
    apply_styles()
    render_hero()
    render_sidebar()

    # right_col is not used below; the split only keeps the form narrower than the page.
    left_col, right_col = st.columns([1.15, 0.85], gap="large")
    with left_col:
        with st.form("classifier-form", clear_on_submit=False):
            title = st.text_input(
                "Название статьи",
                key="article_title",
                placeholder="Например: Attention is all you need",
            )
            abstract = st.text_area(
                "Абстракт",
                key="article_abstract",
                height=280,
                placeholder="Вставьте абстракт статьи. Если не вставишь, ну и фиг с ним.",
            )
            submitted = st.form_submit_button("Крутить барабан (трансформер)", use_container_width=True)
        st.markdown(
            """
<div class="caption-note">
Если abstract пустой, классификация идёт только по названию. ОТВЕТ СНИЗУ.
</div>
""",
            unsafe_allow_html=True,
        )

    if not submitted:
        return

    try:
        with st.spinner("Кручу барабан (трансформер)..."):
            results = predict_topics(title, abstract)
    except ValueError as err:
        # predict_topics raises ValueError with a user-facing prompt when there is no text.
        st.warning(str(err))
        return
    except FileNotFoundError as err:
        # load_model_bundle raises this when the fine-tuned checkpoint has not been trained yet.
        st.error(str(err))
        return
    render_results(results)
if __name__ == "__main__":
    # Convenience launcher. `python <this file>` has no active Streamlit script-run context,
    # so the file is re-executed under `streamlit run`. Under `streamlit run` the context
    # exists and main() is called directly.
    try:
        from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
    except ImportError:
        # Streamlit releases before the scriptrunner -> scriptrunner_utils split expose the
        # same helper under the older module path.
        from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx

    if get_script_run_ctx(suppress_warning=True) is None:
        import subprocess
        import sys

        # Extra command-line arguments are forwarded after the script path, i.e. as script args.
        raise SystemExit(
            subprocess.call(
                [sys.executable, "-m", "streamlit", "run", Path(__file__).resolve().as_posix(), *sys.argv[1:]]
            )
        )
    main()