# SHAD_Homework / src / streamlit_app.py
# AndreyForty's picture
# Update src/streamlit_app.py
# d37b016 verified
from __future__ import annotations
import os
from pathlib import Path
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from paper_classifier import (
BASE_MODEL_NAME,
DEFAULT_MODEL_DIR,
EXAMPLES,
EXPECTED_ARXIV_CATEGORIES,
MAX_LENGTH,
TOP_P_THRESHOLD,
format_input_text,
take_top_p,
)
# Directory the fine-tuned checkpoint is loaded from. The ARXIV_MODEL_DIR
# environment variable, when set, overrides DEFAULT_MODEL_DIR imported from
# paper_classifier.
MODEL_DIR = Path(os.environ.get("ARXIV_MODEL_DIR", str(DEFAULT_MODEL_DIR)))
@st.cache_resource(show_spinner=False)
def load_model_bundle() -> tuple[AutoTokenizer, AutoModelForSequenceClassification]:
    """Load the fine-tuned tokenizer and classifier stored in ``MODEL_DIR``.

    The function is wrapped in ``st.cache_resource`` so Streamlit can reuse the
    loaded objects instead of reading the checkpoint again on every rerun.

    Returns:
        A ``(tokenizer, model)`` pair; the model has been switched to eval mode.

    Raises:
        FileNotFoundError: If ``MODEL_DIR`` does not contain ``config.json``,
            which is taken as the sign that no checkpoint has been trained yet.
    """
    if not (MODEL_DIR / "config.json").exists():
        raise FileNotFoundError(
            f"Не найден fine-tuned checkpoint в {MODEL_DIR}. Сначала обучите модель через train_distilbert.py."
        )

    # Both loaders receive the same POSIX-style path string.
    checkpoint = MODEL_DIR.as_posix()
    text_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    classifier = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    classifier.eval()
    return text_tokenizer, classifier
def predict_topics(title: str, abstract: str) -> list[dict[str, float]]:
    """Classify an article and return its most probable topics.

    The title and abstract are merged by ``format_input_text``, tokenized
    (truncated to ``MAX_LENGTH`` tokens), and passed through the cached model.
    Softmax probabilities are paired with the labels from the model config
    (falling back to ``"Label <index>"``), sorted by descending score, and
    handed to ``take_top_p`` together with ``TOP_P_THRESHOLD``.

    Args:
        title: Article title entered by the user.
        abstract: Article abstract entered by the user.

    Returns:
        The result of ``take_top_p`` applied to records of the form
        ``{"label": str, "score": float}`` ordered from most to least probable.
        (Presumably the shortest prefix reaching the threshold; the exact rule
        lives in ``paper_classifier`` -- TODO confirm.)

    Raises:
        ValueError: If ``format_input_text`` produces an empty text.
    """
    text = format_input_text(title, abstract)
    if not text:
        raise ValueError("Введите хотя бы название статьи или abstract.")

    tokenizer, model = load_model_bundle()
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_LENGTH)

    # Move every input tensor to wherever the model weights live.
    target_device = next(model.parameters()).device
    model_inputs = {key: value.to(target_device) for key, value in encoded.items()}

    with torch.inference_mode():
        # A single text was encoded, so only the first row of logits is used.
        first_row_logits = model(**model_inputs).logits[0]
    scores = torch.softmax(first_row_logits, dim=-1).cpu().tolist()

    label_map = getattr(model.config, "id2label", None)
    if not label_map:
        label_map = {position: f"Label {position}" for position in range(len(scores))}

    ranked = sorted(
        (
            {
                "label": str(label_map.get(position, f"Label {position}")),
                "score": float(probability),
            }
            for position, probability in enumerate(scores)
        ),
        key=lambda item: item["score"],
        reverse=True,
    )
    return take_top_p(ranked, TOP_P_THRESHOLD)
def apply_styles() -> None:
    """Inject the application's custom CSS into the page.

    Emits a single ``<style>`` element through ``st.markdown`` with
    ``unsafe_allow_html=True``. The stylesheet imports the Manrope and
    IBM Plex Mono fonts from Google Fonts, declares the dark colour palette as
    CSS variables, restyles Streamlit widgets (inputs, select boxes, buttons,
    expanders, progress bars, sidebar), and defines the custom classes used by
    the HTML fragments elsewhere in this file (``hero``, ``info-strip``,
    ``info-card``, ``result-card``, ``caption-note`` and friends).
    """
    # The CSS is kept flush-left inside the literal; its text is sent to the
    # browser exactly as written.
    st.markdown(
        """
<style>
@import url('https://fonts.googleapis.com/css2?family=Manrope:wght@400;600;700;800&family=IBM+Plex+Mono:wght@400;500&display=swap');
:root {
--paper: rgba(22, 27, 34, 0.92);
--card: rgba(30, 36, 46, 0.88);
--ink: #e6edf3;
--muted: #8b9cb3;
--accent: #2dd4bf;
--accent-dim: rgba(45, 212, 191, 0.14);
--accent-2: #fb923c;
--border: rgba(230, 237, 243, 0.09);
--shadow: 0 24px 80px rgba(0, 0, 0, 0.45);
--surface-0: #0d1117;
--surface-1: #161b22;
--surface-input: #21262d;
}
.stApp {
background:
radial-gradient(circle at 12% 8%, rgba(45, 212, 191, 0.09), transparent 32%),
radial-gradient(circle at 88% 4%, rgba(251, 146, 60, 0.07), transparent 28%),
linear-gradient(180deg, #0d1117 0%, #0a0e14 100%);
color: var(--ink);
font-family: "Manrope", sans-serif;
}
[data-testid="stAppViewContainer"],
[data-testid="stHeader"] {
background: transparent;
}
[data-testid="stSidebar"] {
background: linear-gradient(180deg, var(--surface-1) 0%, #121820 100%);
border-right: 1px solid var(--border);
}
[data-testid="stSidebar"] .stMarkdown,
[data-testid="stSidebar"] label,
[data-testid="stSidebar"] span {
color: var(--ink) !important;
}
.block-container {
padding-top: 2.2rem;
padding-bottom: 2.2rem;
max-width: 1100px;
}
section.main [data-testid="stMarkdownContainer"] p,
section.main [data-testid="stMarkdownContainer"] li,
section.main label,
.stSubheader {
color: var(--ink) !important;
}
.stTextInput label,
.stTextArea label {
color: var(--muted) !important;
}
.stTextInput input,
.stTextArea textarea {
background-color: var(--surface-input) !important;
color: var(--ink) !important;
border: 1px solid var(--border) !important;
border-radius: 12px !important;
}
.stTextInput input:focus,
.stTextArea textarea:focus {
border-color: rgba(45, 212, 191, 0.45) !important;
box-shadow: 0 0 0 1px rgba(45, 212, 191, 0.25);
}
div[data-baseweb="select"] > div {
background-color: var(--surface-input) !important;
border-color: var(--border) !important;
color: var(--ink) !important;
}
.stButton > button {
background: linear-gradient(135deg, #0d9488 0%, #0f766e 100%) !important;
color: #f0fdfa !important;
border: none !important;
font-weight: 700 !important;
border-radius: 12px !important;
}
.stButton > button:hover {
background: linear-gradient(135deg, #14b8a6 0%, #0d9488 100%) !important;
color: #fff !important;
}
[data-testid="stExpander"] {
background: var(--paper);
border: 1px solid var(--border);
border-radius: 14px;
}
[data-testid="stExpander"] summary {
color: var(--ink) !important;
}
.stProgress > div > div {
background-color: rgba(45, 212, 191, 0.35) !important;
}
.stProgress > div > div > div {
background: linear-gradient(90deg, #2dd4bf, #14b8a6) !important;
}
.hero {
padding: 2rem 2.2rem;
border-radius: 28px;
background: linear-gradient(145deg, rgba(30, 36, 46, 0.95), rgba(22, 27, 34, 0.88));
border: 1px solid var(--border);
box-shadow: var(--shadow);
backdrop-filter: blur(12px);
margin-bottom: 1.2rem;
}
.hero-kicker {
font-size: 0.82rem;
text-transform: uppercase;
letter-spacing: 0.18em;
color: var(--accent);
font-weight: 800;
margin-bottom: 0.65rem;
}
.hero h1 {
font-size: clamp(2rem, 3.5vw, 3.7rem);
line-height: 0.98;
margin: 0;
max-width: 11ch;
color: var(--ink);
}
.hero p {
max-width: 56rem;
color: var(--muted);
font-size: 1.02rem;
line-height: 1.65;
margin-top: 0.95rem;
margin-bottom: 0;
}
.info-strip {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(190px, 1fr));
gap: 0.8rem;
margin: 1rem 0 1.25rem;
}
.info-card {
padding: 1rem 1.05rem;
border-radius: 20px;
background: var(--paper);
border: 1px solid var(--border);
}
.info-label {
color: var(--muted);
font-size: 0.84rem;
margin-bottom: 0.3rem;
}
.info-value {
font-weight: 700;
color: var(--ink);
word-break: break-word;
}
.result-card {
padding: 1rem 1.1rem 1.1rem;
border-radius: 22px;
background: var(--card);
border: 1px solid var(--border);
margin-bottom: 0.9rem;
}
.result-rank {
display: inline-block;
padding: 0.2rem 0.55rem;
margin-bottom: 0.65rem;
border-radius: 999px;
background: var(--accent-dim);
color: var(--accent);
font-size: 0.8rem;
font-weight: 800;
letter-spacing: 0.06em;
text-transform: uppercase;
}
.result-title {
font-size: 1.12rem;
font-weight: 800;
margin-bottom: 0.35rem;
color: var(--ink);
}
.result-score {
color: var(--accent-2);
font-family: "IBM Plex Mono", monospace;
font-size: 0.92rem;
margin-bottom: 0.75rem;
}
.caption-note {
color: var(--muted);
font-size: 0.92rem;
}
[data-testid="stSidebar"] pre,
[data-testid="stSidebar"] code {
background-color: var(--surface-input) !important;
color: #a5f3fc !important;
border: 1px solid var(--border) !important;
border-radius: 10px !important;
}
[data-testid="stSidebar"] [data-testid="stMarkdownContainer"] a {
color: var(--accent) !important;
}
</style>
""",
        unsafe_allow_html=True,
    )
def render_hero() -> None:
    """Draw the page header: the hero banner followed by the info strip.

    The banner is static HTML. The info strip is an f-string that shows
    ``BASE_MODEL_NAME``, the active ``MODEL_DIR`` and ``MAX_LENGTH``. Both
    fragments are rendered through ``st.markdown`` with raw HTML enabled and
    rely on the CSS classes installed by ``apply_styles``.
    """
    banner_html = """
<section class="hero">
<div class="hero-kicker">Моисейин Андрей Денисович</div>
<h1>Классификатор научных статей</h1>
<p>
Вот не зря я учил веб разработку 4 года, чтобы писать на html, css и js. Эх, был бы реакт.
</p>
</section>
"""
    info_strip_html = f"""
<div class="info-strip">
<div class="info-card">
<div class="info-label">Базовая модель</div>
<div class="info-value">{BASE_MODEL_NAME}</div>
</div>
<div class="info-card">
<div class="info-label">Checkpoint</div>
<div class="info-value">{MODEL_DIR}</div>
</div>
<div class="info-card">
<div class="info-label">Макс. длина</div>
<div class="info-value">{MAX_LENGTH} токенов</div>
</div>
</div>
"""
    # Banner first, then the strip -- same order as they appear on the page.
    for fragment in (banner_html, info_strip_html):
        st.markdown(fragment, unsafe_allow_html=True)
def render_results(records: list[dict[str, float]]) -> None:
    """Render the predicted topics as ranked cards with progress bars.

    Each record becomes an HTML card (rank, label, probability) followed by a
    progress bar whose value is the score clamped to ``[0.0, 1.0]``. A final
    caption shows the summed probability of all displayed records.

    Args:
        records: Items with a ``"label"`` (display name) and a ``"score"``
            (probability), in the order they should be shown.
    """
    # Function-scope import keeps this block self-contained.
    from html import escape

    st.subheader("Ответ")
    st.caption("Классы отсортированы по убыванию вероятности. Показаны только те, которые набрали 95%.")

    for rank, record in enumerate(records, start=1):
        score = record["score"]
        # Labels come from the checkpoint's config, not from this file. They are
        # escaped because the card is rendered with unsafe_allow_html=True, so a
        # label containing '<' or '&' would otherwise be treated as markup.
        safe_label = escape(str(record["label"]))
        st.markdown(
            f"""
<div class="result-card">
<div class="result-rank">#{rank}</div>
<div class="result-title">{safe_label}</div>
<div class="result-score">p = {score:.2%}</div>
</div>
""",
            unsafe_allow_html=True,
        )
        # NOTE(review): SOURCE lost its indentation; the bar is assumed to be
        # drawn once per record because it reads the loop variable.
        st.progress(min(max(score, 0.0), 1.0))

    st.caption(
        f"Суммарная вероятность показанных тем: {sum(record['score'] for record in records):.2%}"
    )
def render_sidebar() -> None:
    """Draw the sidebar preset picker and sync it with the form fields.

    Session-state keys ``selected_preset``, ``article_title`` and
    ``article_abstract`` are created on first use. When the selected preset
    differs from the one stored in session state, the title and abstract are
    replaced by the chosen example from ``EXAMPLES`` (or cleared for the
    "own text" option) and the stored preset is updated.
    """
    state = st.session_state
    initial_values = (
        ("selected_preset", "Свой текст"),
        ("article_title", ""),
        ("article_abstract", ""),
    )
    for key, initial in initial_values:
        if key not in state:
            state[key] = initial

    st.sidebar.markdown("### Быстрый старт")
    chosen = st.sidebar.selectbox(
        "Пример статьи",
        options=["Свой текст", *EXAMPLES.keys()],
    )

    # Nothing to do while the user keeps the same preset selected.
    if chosen == state.selected_preset:
        return

    if chosen == "Свой текст":
        new_title, new_abstract = "", ""
    else:
        example = EXAMPLES[chosen]
        new_title, new_abstract = example["title"], example["abstract"]

    state.article_title = new_title
    state.article_abstract = new_abstract
    state.selected_preset = chosen
def main() -> None:
    """Build the page: styles, header, sidebar, input form and results.

    After the form is submitted the title and abstract are passed to
    ``predict_topics`` and the outcome is shown with ``render_results``.
    Errors that ``predict_topics`` raises on purpose -- ``ValueError`` for
    empty input and ``FileNotFoundError`` for a missing checkpoint -- are
    shown to the user as an error message instead of escaping as a traceback.
    """
    st.set_page_config(
        page_title="Article Topic Classifier",
        layout="wide",
    )
    apply_styles()
    render_hero()
    render_sidebar()

    # NOTE(review): SOURCE lost its indentation. The nesting below (form and
    # hint inside the left column, results rendered below the columns) is a
    # reconstruction -- confirm against the original layout.
    # The second column is created for layout only and is never populated.
    left_col, _ = st.columns([1.15, 0.85], gap="large")
    with left_col:
        with st.form("classifier-form", clear_on_submit=False):
            # The widget keys match the session-state entries that
            # render_sidebar fills in when a preset is chosen.
            title = st.text_input(
                "Название статьи",
                key="article_title",
                placeholder="Например: Attention is all you need",
            )
            abstract = st.text_area(
                "Абстракт",
                key="article_abstract",
                height=280,
                placeholder="Вставьте абстракт статьи. Если не вставишь, ну и фиг с ним.",
            )
            submitted = st.form_submit_button("Крутить барабан (трансформер)", use_container_width=True)
        st.markdown(
            """
<div class="caption-note">
Если abstract пустой, классификация идёт только по названию. ОТВЕТ СНИЗУ.
</div>
""",
            unsafe_allow_html=True,
        )

    if not submitted:
        return

    try:
        with st.spinner("Кручу барабан (трансформер)..."):
            results = predict_topics(title, abstract)
    except (ValueError, FileNotFoundError) as error:
        # Both exceptions carry messages written for the end user.
        st.error(str(error))
        return

    render_results(results)
if __name__ == "__main__":
    # How convenient: the script can be started with plain `python` as well.
    # The context helper is imported lazily, only when run as a script.
    from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx

    # No script-run context presumably means the file was not launched through
    # `streamlit run`; re-launch it under Streamlit, forwarding any extra CLI
    # arguments, and exit with the child's return code.
    if get_script_run_ctx(suppress_warning=True) is None:
        import subprocess
        import sys

        raise SystemExit(
            subprocess.call(
                [sys.executable, "-m", "streamlit", "run", Path(__file__).resolve().as_posix(), *sys.argv[1:]]
            )
        )
    # NOTE(review): indentation was lost in SOURCE; main() is assumed to sit
    # inside the __main__ guard, after the re-launch check.
    main()