AndreyForty commited on
Commit
ed5c425
·
verified ·
1 Parent(s): 9eb9bb1

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +43 -20
  2. app.py +446 -0
  3. paper_classifier.py +73 -0
  4. requirements.txt +7 -3
  5. train_distilbert.py +229 -0
README.md CHANGED
@@ -1,20 +1,43 @@
1
- ---
2
- title: SHAD Homework
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: red
6
- sdk: docker
7
- app_port: 8501
8
- tags:
9
- - streamlit
10
- pinned: false
11
- short_description: Задание шад Моисейкин Андрей
12
- license: mit
13
- ---
14
-
15
- # Welcome to Streamlit!
16
-
17
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
18
-
19
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
20
- forums](https://discuss.streamlit.io).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Реализация задания из ноутбука через `streamlit` и `finetune` модели
2
+ `distilbert/distilbert-base-cased` для классификации научных статей
3
+
4
+ ## Что за файлики
5
+
6
+ - `train_distilbert.py` - обучение модели на датасете архива `arxivData.json` из Kaggle.
7
+ - `app.py` - веб-интерфейс на streamlit, который загружает уже обученный чекпоинт
8
+ - `paper_classifier.py` - общие константы, примеры
9
+
10
+
11
+ Используются поля:
12
+
13
+ - `title`
14
+ - `summary`
15
+ - `tag`
16
+
17
+
18
+ ## Обучение
19
+
20
+ ```bash
21
+ conda activate main
22
+ pip install -r requirements.txt
23
+ python train_distilbert.py
24
+ ```
25
+
26
+ По умолчанию checkpoint будет сохранён в `artifacts/distilbert-arxiv`.
27
+
28
+ ## Запуск streamlit
29
+
30
+ После обучения:
31
+
32
+ ```bash
33
+ conda activate main
34
+ streamlit run app.py --server.port 8080
35
+ ```
36
+
37
+ После запуска откройте `http://localhost:8080`.
38
+
39
+ ## Как работает инференс
40
+
41
+ - модель читает `title` и `abstract`
42
+ - если `abstract` пустой, используется только название статьи
43
+ - сервис показывает только те классы, которые суммарно набирают `95%` вероятности по категориям
app.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from pathlib import Path
5
+
6
+ import streamlit as st
7
+ import torch
8
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
9
+
10
+ from paper_classifier import (
11
+ BASE_MODEL_NAME,
12
+ DEFAULT_MODEL_DIR,
13
+ EXAMPLES,
14
+ EXPECTED_ARXIV_CATEGORIES,
15
+ MAX_LENGTH,
16
+ TOP_P_THRESHOLD,
17
+ format_input_text,
18
+ take_top_p,
19
+ )
20
+
21
# Checkpoint directory; can be overridden via the ARXIV_MODEL_DIR environment variable.
MODEL_DIR = Path(os.environ.get("ARXIV_MODEL_DIR", DEFAULT_MODEL_DIR))
22
+
23
+
24
@st.cache_resource(show_spinner=False)
def load_model_bundle() -> tuple[AutoTokenizer, AutoModelForSequenceClassification]:
    """Load the fine-tuned tokenizer/model pair from MODEL_DIR.

    Cached by Streamlit for the process lifetime, so the checkpoint is
    read from disk only once.

    Raises:
        FileNotFoundError: if no checkpoint (config.json) exists in MODEL_DIR.
    """
    if not (MODEL_DIR / "config.json").exists():
        raise FileNotFoundError(
            f"Не найден fine-tuned checkpoint в {MODEL_DIR}. Сначала обучите модель через train_distilbert.py."
        )

    checkpoint = MODEL_DIR.as_posix()
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    # Inference only: switch off dropout etc.
    model.eval()
    return tokenizer, model
36
+
37
+
38
def predict_topics(title: str, abstract: str) -> list[dict[str, float]]:
    """Classify an article and return the smallest top-p set of label records.

    Each record is {"label": str, "score": float}; records come back sorted
    by descending probability and truncated at TOP_P_THRESHOLD cumulative mass.

    Raises:
        ValueError: if both title and abstract are effectively empty.
    """
    article_text = format_input_text(title, abstract)
    if not article_text:
        raise ValueError("Введите хотя бы название статьи или abstract.")

    tokenizer, model = load_model_bundle()
    device = next(model.parameters()).device
    encoded = tokenizer(
        article_text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.inference_mode():
        raw_logits = model(**encoded).logits[0]
        probabilities = torch.softmax(raw_logits, dim=-1).cpu().tolist()

    # Fall back to synthetic names if the checkpoint lacks an id2label mapping.
    id2label = getattr(model.config, "id2label", None) or {
        position: f"Label {position}" for position in range(len(probabilities))
    }

    records: list[dict[str, float]] = []
    for position, probability in enumerate(probabilities):
        records.append(
            {
                "label": str(id2label.get(position, f"Label {position}")),
                "score": float(probability),
            }
        )
    records.sort(key=lambda record: record["score"], reverse=True)
    return take_top_p(records, TOP_P_THRESHOLD)
69
+
70
+
71
def apply_styles() -> None:
    """Inject the app's global CSS theme.

    Dark palette via CSS custom properties, Google fonts, and overrides for
    Streamlit's built-in widgets plus the custom hero/info/result cards.
    """
    st.markdown(
        """
        <style>
        @import url('https://fonts.googleapis.com/css2?family=Manrope:wght@400;600;700;800&family=IBM+Plex+Mono:wght@400;500&display=swap');

        :root {
            --paper: rgba(22, 27, 34, 0.92);
            --card: rgba(30, 36, 46, 0.88);
            --ink: #e6edf3;
            --muted: #8b9cb3;
            --accent: #2dd4bf;
            --accent-dim: rgba(45, 212, 191, 0.14);
            --accent-2: #fb923c;
            --border: rgba(230, 237, 243, 0.09);
            --shadow: 0 24px 80px rgba(0, 0, 0, 0.45);
            --surface-0: #0d1117;
            --surface-1: #161b22;
            --surface-input: #21262d;
        }

        .stApp {
            background:
                radial-gradient(circle at 12% 8%, rgba(45, 212, 191, 0.09), transparent 32%),
                radial-gradient(circle at 88% 4%, rgba(251, 146, 60, 0.07), transparent 28%),
                linear-gradient(180deg, #0d1117 0%, #0a0e14 100%);
            color: var(--ink);
            font-family: "Manrope", sans-serif;
        }

        [data-testid="stAppViewContainer"],
        [data-testid="stHeader"] {
            background: transparent;
        }

        [data-testid="stSidebar"] {
            background: linear-gradient(180deg, var(--surface-1) 0%, #121820 100%);
            border-right: 1px solid var(--border);
        }

        [data-testid="stSidebar"] .stMarkdown,
        [data-testid="stSidebar"] label,
        [data-testid="stSidebar"] span {
            color: var(--ink) !important;
        }

        .block-container {
            padding-top: 2.2rem;
            padding-bottom: 2.2rem;
            max-width: 1100px;
        }

        section.main [data-testid="stMarkdownContainer"] p,
        section.main [data-testid="stMarkdownContainer"] li,
        section.main label,
        .stSubheader {
            color: var(--ink) !important;
        }

        .stTextInput label,
        .stTextArea label {
            color: var(--muted) !important;
        }

        .stTextInput input,
        .stTextArea textarea {
            background-color: var(--surface-input) !important;
            color: var(--ink) !important;
            border: 1px solid var(--border) !important;
            border-radius: 12px !important;
        }

        .stTextInput input:focus,
        .stTextArea textarea:focus {
            border-color: rgba(45, 212, 191, 0.45) !important;
            box-shadow: 0 0 0 1px rgba(45, 212, 191, 0.25);
        }

        div[data-baseweb="select"] > div {
            background-color: var(--surface-input) !important;
            border-color: var(--border) !important;
            color: var(--ink) !important;
        }

        .stButton > button {
            background: linear-gradient(135deg, #0d9488 0%, #0f766e 100%) !important;
            color: #f0fdfa !important;
            border: none !important;
            font-weight: 700 !important;
            border-radius: 12px !important;
        }

        .stButton > button:hover {
            background: linear-gradient(135deg, #14b8a6 0%, #0d9488 100%) !important;
            color: #fff !important;
        }

        [data-testid="stExpander"] {
            background: var(--paper);
            border: 1px solid var(--border);
            border-radius: 14px;
        }

        [data-testid="stExpander"] summary {
            color: var(--ink) !important;
        }

        .stProgress > div > div {
            background-color: rgba(45, 212, 191, 0.35) !important;
        }

        .stProgress > div > div > div {
            background: linear-gradient(90deg, #2dd4bf, #14b8a6) !important;
        }

        .hero {
            padding: 2rem 2.2rem;
            border-radius: 28px;
            background: linear-gradient(145deg, rgba(30, 36, 46, 0.95), rgba(22, 27, 34, 0.88));
            border: 1px solid var(--border);
            box-shadow: var(--shadow);
            backdrop-filter: blur(12px);
            margin-bottom: 1.2rem;
        }

        .hero-kicker {
            font-size: 0.82rem;
            text-transform: uppercase;
            letter-spacing: 0.18em;
            color: var(--accent);
            font-weight: 800;
            margin-bottom: 0.65rem;
        }

        .hero h1 {
            font-size: clamp(2rem, 3.5vw, 3.7rem);
            line-height: 0.98;
            margin: 0;
            max-width: 11ch;
            color: var(--ink);
        }

        .hero p {
            max-width: 56rem;
            color: var(--muted);
            font-size: 1.02rem;
            line-height: 1.65;
            margin-top: 0.95rem;
            margin-bottom: 0;
        }

        .info-strip {
            display: grid;
            grid-template-columns: repeat(auto-fit, minmax(190px, 1fr));
            gap: 0.8rem;
            margin: 1rem 0 1.25rem;
        }

        .info-card {
            padding: 1rem 1.05rem;
            border-radius: 20px;
            background: var(--paper);
            border: 1px solid var(--border);
        }

        .info-label {
            color: var(--muted);
            font-size: 0.84rem;
            margin-bottom: 0.3rem;
        }

        .info-value {
            font-weight: 700;
            color: var(--ink);
            word-break: break-word;
        }

        .result-card {
            padding: 1rem 1.1rem 1.1rem;
            border-radius: 22px;
            background: var(--card);
            border: 1px solid var(--border);
            margin-bottom: 0.9rem;
        }

        .result-rank {
            display: inline-block;
            padding: 0.2rem 0.55rem;
            margin-bottom: 0.65rem;
            border-radius: 999px;
            background: var(--accent-dim);
            color: var(--accent);
            font-size: 0.8rem;
            font-weight: 800;
            letter-spacing: 0.06em;
            text-transform: uppercase;
        }

        .result-title {
            font-size: 1.12rem;
            font-weight: 800;
            margin-bottom: 0.35rem;
            color: var(--ink);
        }

        .result-score {
            color: var(--accent-2);
            font-family: "IBM Plex Mono", monospace;
            font-size: 0.92rem;
            margin-bottom: 0.75rem;
        }

        .caption-note {
            color: var(--muted);
            font-size: 0.92rem;
        }

        [data-testid="stSidebar"] pre,
        [data-testid="stSidebar"] code {
            background-color: var(--surface-input) !important;
            color: #a5f3fc !important;
            border: 1px solid var(--border) !important;
            border-radius: 10px !important;
        }

        [data-testid="stSidebar"] [data-testid="stMarkdownContainer"] a {
            color: var(--accent) !important;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )
303
+
304
+
305
def render_hero() -> None:
    """Render the page header card and an info strip with model metadata."""
    st.markdown(
        """
        <section class="hero">
            <div class="hero-kicker">Моисейин Андрей Денисович</div>
            <h1>Классификатор научных статей</h1>
            <p>
                Вот не зря я учил веб разработку 4 года, чтобы писать на html, css и js. Эх, был бы реакт.
            </p>
        </section>
        """,
        unsafe_allow_html=True,
    )

    # Info strip: base model id, checkpoint path, tokenizer truncation length.
    st.markdown(
        f"""
        <div class="info-strip">
            <div class="info-card">
                <div class="info-label">Базовая модель</div>
                <div class="info-value">{BASE_MODEL_NAME}</div>
            </div>
            <div class="info-card">
                <div class="info-label">Checkpoint</div>
                <div class="info-value">{MODEL_DIR}</div>
            </div>
            <div class="info-card">
                <div class="info-label">Макс. длина</div>
                <div class="info-value">{MAX_LENGTH} токенов</div>
            </div>
        </div>
        """,
        unsafe_allow_html=True,
    )
338
+
339
+
340
def render_results(records: list[dict[str, float]]) -> None:
    """Render one ranked card per predicted label plus a cumulative caption.

    Args:
        records: label/score dicts, already sorted by descending score and
            truncated by take_top_p.
    """
    st.subheader("Ответ")
    st.caption("Классы отсортированы по убыванию вероятности. Показаны только те, которые набрали 95%.")

    for index, record in enumerate(records, start=1):
        st.markdown(
            f"""
            <div class="result-card">
                <div class="result-rank">#{index}</div>
                <div class="result-title">{record["label"]}</div>
                <div class="result-score">p = {record["score"]:.2%}</div>
            </div>
            """,
            unsafe_allow_html=True,
        )
        # Clamp to [0, 1] so st.progress never raises on rounding artifacts.
        st.progress(min(max(record["score"], 0.0), 1.0))

    st.caption(
        f"Суммарная вероятность показанных тем: {sum(record['score'] for record in records):.2%}"
    )
360
+
361
+
362
def render_sidebar() -> None:
    """Sidebar preset picker that fills the title/abstract fields from EXAMPLES.

    Session-state keys are initialised on first run; switching the preset
    overwrites the form fields, and choosing "Свой текст" clears them.
    """
    defaults = {
        "selected_preset": "Свой текст",
        "article_title": "",
        "article_abstract": "",
    }
    for state_key, initial_value in defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = initial_value

    st.sidebar.markdown("### Быстрый старт")
    preset_name = st.sidebar.selectbox(
        "Пример статьи",
        options=["Свой текст"] + list(EXAMPLES.keys()),
    )

    # Nothing changed since the last run — keep the fields as they are.
    if preset_name == st.session_state.selected_preset:
        return

    if preset_name == "Свой текст":
        st.session_state.article_title = ""
        st.session_state.article_abstract = ""
    else:
        preset = EXAMPLES[preset_name]
        st.session_state.article_title = preset["title"]
        st.session_state.article_abstract = preset["abstract"]
    st.session_state.selected_preset = preset_name
384
+
385
+
386
+
387
+
388
def main() -> None:
    """Wire the Streamlit page together: styles, hero, sidebar, form, results."""
    st.set_page_config(
        page_title="Article Topic Classifier",
        layout="wide",
    )
    apply_styles()
    render_hero()
    render_sidebar()

    left_col, right_col = st.columns([1.15, 0.85], gap="large")

    with left_col:
        # Form batches the two inputs so the model runs only on submit.
        with st.form("classifier-form", clear_on_submit=False):
            title = st.text_input(
                "Название статьи",
                key="article_title",
                placeholder="Например: Attention is all you need",
            )
            abstract = st.text_area(
                "Абстракт",
                key="article_abstract",
                height=280,
                placeholder="Вставьте абстракт статьи. Если не вставишь, ну и фиг с ним.",
            )
            submitted = st.form_submit_button("Крутить барабан (трансформер)", use_container_width=True)

        st.markdown(
            """
            <div class="caption-note">
                Если abstract пустой, классификация идёт только по названию. ОТВЕТ СНИЗУ.
            </div>
            """,
            unsafe_allow_html=True,
        )

    # Nothing to do until the form is submitted.
    if not submitted:
        return

    with st.spinner("Кручу барабан (трансформер)..."):
        results = predict_topics(title, abstract)

    render_results(results)
430
+
431
+
432
if __name__ == "__main__":

    # Convenience: lets `python app.py` relaunch itself under `streamlit run`.
    from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx

    # No script-run context means we were started as a plain Python process,
    # not by Streamlit — re-exec through the Streamlit CLI and exit with its code.
    if get_script_run_ctx(suppress_warning=True) is None:
        import subprocess
        import sys

        raise SystemExit(
            subprocess.call(
                [sys.executable, "-m", "streamlit", "run", Path(__file__).resolve().as_posix(), *sys.argv[1:]]
            )
        )
    main()
paper_classifier.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Iterable
4
+
5
# Hugging Face Hub id of the pretrained encoder to fine-tune.
BASE_MODEL_NAME = "distilbert/distilbert-base-cased"
# Where train_distilbert.py writes (and app.py reads) the checkpoint.
DEFAULT_MODEL_DIR = "artifacts/distilbert-arxiv"
# Tokenizer truncation length shared by training and inference.
MAX_LENGTH = 256
# Cumulative probability mass shown to the user (top-p cutoff).
TOP_P_THRESHOLD = 0.95
# Coarse top-level arXiv subject areas the classifier is expected to cover.
EXPECTED_ARXIV_CATEGORIES = [
    "Computer Science",
    "Physics",
    "Mathematics",
    "Statistics",
    "Quantitative Biology",
    "Quantitative Finance",
    "Economics",
    "Electrical Engineering and Systems Science",
]
# Preset demo articles offered in the Streamlit sidebar.
EXAMPLES = {
    "Graph Neural Networks": {
        "title": "Message Passing Neural Networks for Molecular Property Prediction",
        "abstract": (
            "We introduce a graph-based neural architecture for supervised learning on "
            "molecular graphs. The model propagates messages between atoms, aggregates "
            "node states into a graph embedding, and predicts physical and chemical "
            "properties with competitive accuracy."
        ),
    },
    "Physics": {
        "title": "Topological phase transitions in two-dimensional quantum materials",
        "abstract": (
            "We study a lattice model with strong spin-orbit coupling and show how "
            "interactions modify the phase diagram. Using numerical simulations we "
            "characterize edge states, quantify transport signatures, and discuss "
            "observable consequences for low-temperature experiments."
        ),
    },
    "Bioinformatics": {
        "title": "Transformer models for protein function annotation from sequence",
        "abstract": (
            "We pretrain a transformer encoder on amino acid sequences and finetune it "
            "for protein function prediction. The approach improves annotation quality "
            "for underrepresented families and reveals biologically meaningful sequence "
            "patterns."
        ),
    },
}
48
+
49
+
50
def format_input_text(title: str, abstract: str) -> str:
    """Build the model input string from a title and an abstract.

    Both pieces are whitespace-trimmed; empty pieces are dropped, and the
    remaining ones are joined with a blank line. Returns "" when both
    inputs are blank.
    """
    sections: list[str] = []

    cleaned_title = title.strip()
    if cleaned_title:
        # NOTE(review): the title is repeated as "Title summary" — presumably to
        # weight it more heavily; training and inference both use this format,
        # so it must not be changed in only one place.
        sections.append(f"Title: {cleaned_title}\nTitle summary: {cleaned_title}")

    cleaned_abstract = abstract.strip()
    if cleaned_abstract:
        sections.append(f"Abstract: {cleaned_abstract}")

    return "\n\n".join(sections)
61
+
62
+
63
def take_top_p(records: Iterable[dict[str, float]], threshold: float) -> list[dict[str, float]]:
    """Return the shortest prefix of records whose scores sum to >= threshold.

    Records are consumed in the given order (callers pass them sorted by
    descending score); the record that crosses the threshold is included.
    If the threshold is never reached, every record is returned.
    """
    chosen: list[dict[str, float]] = []
    total = 0.0

    for record in records:
        chosen.append(record)
        total += record["score"]
        if total >= threshold:
            return chosen

    # Threshold never reached (e.g. empty input or scores summing below it).
    return chosen
requirements.txt CHANGED
@@ -1,3 +1,7 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
1
+ streamlit
2
+ transformers
3
+ torch
4
+ safetensors
5
+ datasets
6
+ accelerate
7
+ scikit-learn
train_distilbert.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ import json
5
+ from collections import Counter
6
+ from functools import partial
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ from datasets import Dataset, DatasetDict
11
+ from sklearn.metrics import accuracy_score, f1_score
12
+ from transformers import (
13
+ AutoModelForSequenceClassification,
14
+ AutoTokenizer,
15
+ DataCollatorWithPadding,
16
+ Trainer,
17
+ TrainingArguments,
18
+ set_seed,
19
+ )
20
+
21
+ from paper_classifier import BASE_MODEL_NAME, DEFAULT_MODEL_DIR, MAX_LENGTH, format_input_text
22
+
23
# Input dataset (arXiv metadata dump, JSON list) and output locations.
DATA_PATH = Path("arxivData.json")
OUTPUT_DIR = Path(DEFAULT_MODEL_DIR)
HF_CACHE_DIR = Path("/tmp/huggingface")

# Field names inside each JSON record of arxivData.json.
TITLE_FIELD = "title"
ABSTRACT_FIELD = "summary"
TAG_FIELD = "tag"

# Split / training hyperparameters.
VALIDATION_SIZE = 0.1
NUM_TRAIN_EPOCHS = 4
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
PER_DEVICE_TRAIN_BATCH_SIZE = 16
PER_DEVICE_EVAL_BATCH_SIZE = 32
LOGGING_STEPS = 50
SEED = 42

# Maps an arXiv tag prefix (the part before the first ".") to the coarse
# top-level subject label used as the classification target.
PREFIX_TO_LABEL = {
    "adap-org": "Quantitative Biology",
    "astro-ph": "Physics",
    "cmp-lg": "Computer Science",
    "cond-mat": "Physics",
    "cs": "Computer Science",
    "econ": "Economics",
    "eess": "Electrical Engineering and Systems Science",
    "gr-qc": "Physics",
    "hep-ex": "Physics",
    "hep-lat": "Physics",
    "hep-ph": "Physics",
    "hep-th": "Physics",
    "math": "Mathematics",
    "nlin": "Physics",
    "nucl-th": "Physics",
    "physics": "Physics",
    "q-bio": "Quantitative Biology",
    "q-fin": "Quantitative Finance",
    "quant-ph": "Physics",
    "stat": "Statistics",
}
+ }
62
+
63
+
64
def normalize_text(value):
    """Collapse runs of whitespace and stringify; falsy values become ""."""
    safe = value or ""
    return " ".join(str(safe).split())
66
+
67
+
68
def parse_top_level_label(raw_tag):
    """Map a raw arXiv tag payload to a coarse subject label, or None.

    raw_tag is expected to hold the string repr of a list of dicts with a
    "term" key (as stored in arxivData.json). The first term whose prefix
    (text before the first ".") is known in PREFIX_TO_LABEL decides the
    label; anything unparseable yields None.
    """
    if not raw_tag:
        return None

    try:
        tags = ast.literal_eval(str(raw_tag))
    except (SyntaxError, ValueError):
        # Not a Python literal — treat the record as unlabelled.
        return None

    if not isinstance(tags, list):
        return None

    for entry in tags:
        if not isinstance(entry, dict):
            continue
        term = entry.get("term")
        if not term:
            continue
        # e.g. "cs.LG" -> "cs"; the top-level prefix picks the category.
        label = PREFIX_TO_LABEL.get(str(term).split(".")[0])
        if label:
            return label

    return None
92
+
93
+
94
def build_records():
    """Load DATA_PATH and turn it into {"text", "label"} training records.

    Records whose tag cannot be mapped to a known top-level label, or whose
    title+abstract normalise to an empty string, are skipped and tallied in
    a Counter. Keeping them (the previous behaviour: label=None slipped
    through and the `skipped` Counter was never used) breaks the later
    `sorted({record["label"] ...})` in main() — None is not orderable
    against str — and would feed invalid labels to the Trainer.

    Returns:
        list[dict[str, str]]: prepared records with non-empty text and a
        valid coarse label.
    """
    with DATA_PATH.open("r", encoding="utf-8") as file:
        raw_records = json.load(file)

    prepared_records: list[dict[str, str]] = []
    skipped = Counter()

    for item in raw_records:
        title = normalize_text(item.get(TITLE_FIELD))
        abstract = normalize_text(item.get(ABSTRACT_FIELD))
        label = parse_top_level_label(item.get(TAG_FIELD))
        text = format_input_text(title, abstract)

        if label is None:
            skipped["unparseable_label"] += 1
            continue
        if not text:
            skipped["empty_text"] += 1
            continue

        prepared_records.append(
            {
                "text": text,
                "label": label,
            }
        )

    print(f"Loaded {len(prepared_records)}")
    if skipped:
        print("Skipped:", dict(skipped))

    label_distribution = Counter(record["label"] for record in prepared_records)
    print("Label distribution:", dict(label_distribution))
    return prepared_records
118
+
119
+
120
def build_splits(records):
    """Split prepared records into a seeded train/validation DatasetDict."""
    full_dataset = Dataset.from_list(records)
    parts = full_dataset.train_test_split(test_size=VALIDATION_SIZE, seed=SEED)
    return DatasetDict(train=parts["train"], validation=parts["test"])
124
+
125
+
126
def preprocess(batch, *, tokenizer, label2id):
    """Tokenize a batch of texts and attach integer class labels."""
    encoded = tokenizer(batch["text"], truncation=True, max_length=MAX_LENGTH)
    encoded["labels"] = [label2id[name] for name in batch["label"]]
    return encoded
130
+
131
+
132
def compute_metrics(eval_prediction):
    """Compute accuracy and macro-F1 from a (logits, labels) pair."""
    logits, labels = eval_prediction
    predicted = np.argmax(logits, axis=-1)
    metrics = {
        "accuracy": accuracy_score(labels, predicted),
        "macro_f1": f1_score(labels, predicted, average="macro"),
    }
    return metrics
139
+
140
+
141
def main() -> None:
    """Fine-tune DistilBERT on the arXiv records and save checkpoint + summary."""
    if not DATA_PATH.exists():
        raise FileNotFoundError(f"Dataset file not found: {DATA_PATH}")

    set_seed(SEED)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    HF_CACHE_DIR.mkdir(parents=True, exist_ok=True)

    records = build_records()
    raw_splits = build_splits(records)

    # Deterministic label ordering -> stable label2id mapping across runs.
    label_names = sorted({record["label"] for record in records})
    label2id = {label: index for index, label in enumerate(label_names)}
    id2label = {index: label for label, index in label2id.items()}

    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL_NAME,
        cache_dir=HF_CACHE_DIR.as_posix(),
    )

    # Tokenize both splits; drop the raw text/label columns afterwards.
    tokenized_splits = raw_splits.map(
        partial(preprocess, tokenizer=tokenizer, label2id=label2id),
        batched=True,
        remove_columns=raw_splits["train"].column_names,
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL_NAME,
        cache_dir=HF_CACHE_DIR.as_posix(),
        num_labels=len(label_names),
        id2label=id2label,
        label2id=label2id,
    )
    # Dynamic per-batch padding instead of padding everything to MAX_LENGTH.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR.as_posix(),
        do_train=True,
        do_eval=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=LOGGING_STEPS,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
        load_best_model_at_end=True,  # restore the epoch with the best macro_f1
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        save_total_limit=2,
        report_to=[],  # disable wandb/tensorboard reporting
        seed=SEED,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_splits["train"],
        eval_dataset=tokenized_splits["validation"],
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    metrics = trainer.evaluate()
    # Save model + tokenizer so app.py can load both from OUTPUT_DIR.
    trainer.save_model(OUTPUT_DIR.as_posix())
    tokenizer.save_pretrained(OUTPUT_DIR.as_posix())

    # Persist a JSON run summary next to the checkpoint for reproducibility.
    summary_path = OUTPUT_DIR / "training_summary.json"
    summary = {
        "base_model": BASE_MODEL_NAME,
        "data_path": DATA_PATH.as_posix(),
        "output_dir": OUTPUT_DIR.as_posix(),
        "title_field": TITLE_FIELD,
        "abstract_field": ABSTRACT_FIELD,
        "tag_field": TAG_FIELD,
        "labels": label_names,
        "metrics": metrics,
    }
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")

    print(json.dumps(summary, indent=2))
226
+
227
+
228
if __name__ == "__main__":
    # Script entry point: run the full fine-tuning pipeline.
    main()