AndreyForty committed on
Commit
af53d97
·
verified ·
1 Parent(s): 2f856a9

Upload 2 files

Browse files
Files changed (2) hide show
  1. src/paper_classifier.py +73 -0
  2. src/train_distilbert.py +229 -0
src/paper_classifier.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Iterable
4
+
5
# Hugging Face checkpoint used as the finetuning base.
BASE_MODEL_NAME = "distilbert/distilbert-base-cased"
# Default location of the finetuned classifier artifacts.
DEFAULT_MODEL_DIR = "artifacts/distilbert-arxiv"
# Token budget for each title+abstract input.
MAX_LENGTH = 256
# Cumulative-probability cutoff for trimming ranked predictions.
TOP_P_THRESHOLD = 0.95

# The eight top-level arXiv subject groups the classifier distinguishes.
EXPECTED_ARXIV_CATEGORIES = [
    "Computer Science",
    "Physics",
    "Mathematics",
    "Statistics",
    "Quantitative Biology",
    "Quantitative Finance",
    "Economics",
    "Electrical Engineering and Systems Science",
]

# Demo papers (title + abstract) for quick manual sanity checks.
EXAMPLES = {
    "Graph Neural Networks": {
        "title": "Message Passing Neural Networks for Molecular Property Prediction",
        "abstract": (
            "We introduce a graph-based neural architecture for supervised learning on "
            "molecular graphs. The model propagates messages between atoms, aggregates "
            "node states into a graph embedding, and predicts physical and chemical "
            "properties with competitive accuracy."
        ),
    },
    "Physics": {
        "title": "Topological phase transitions in two-dimensional quantum materials",
        "abstract": (
            "We study a lattice model with strong spin-orbit coupling and show how "
            "interactions modify the phase diagram. Using numerical simulations we "
            "characterize edge states, quantify transport signatures, and discuss "
            "observable consequences for low-temperature experiments."
        ),
    },
    "Bioinformatics": {
        "title": "Transformer models for protein function annotation from sequence",
        "abstract": (
            "We pretrain a transformer encoder on amino acid sequences and finetune it "
            "for protein function prediction. The approach improves annotation quality "
            "for underrepresented families and reveals biologically meaningful sequence "
            "patterns."
        ),
    },
}
48
+
49
+
50
def format_input_text(title: str, abstract: str) -> str:
    """Build the model input string from a paper title and abstract.

    Leading/trailing whitespace is stripped and empty fields are omitted.
    The two sections are joined with a blank line.

    NOTE(review): the title is deliberately repeated on a "Title summary"
    line — presumably the format the model was trained with; keep it in
    sync with training-time formatting.
    """
    clean_title = title.strip()
    clean_abstract = abstract.strip()

    sections: list[str] = []
    if clean_title:
        sections.append(f"Title: {clean_title}\nTitle summary: {clean_title}")
    if clean_abstract:
        sections.append(f"Abstract: {clean_abstract}")

    return "\n\n".join(sections)
61
+
62
+
63
+ def take_top_p(records: Iterable[dict[str, float]], threshold: float) -> list[dict[str, float]]:
64
+ selected: list[dict[str, float]] = []
65
+ cumulative = 0.0
66
+
67
+ for record in records:
68
+ selected.append(record)
69
+ cumulative += record["score"]
70
+ if cumulative >= threshold:
71
+ break
72
+
73
+ return selected
src/train_distilbert.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ import json
5
+ from collections import Counter
6
+ from functools import partial
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ from datasets import Dataset, DatasetDict
11
+ from sklearn.metrics import accuracy_score, f1_score
12
+ from transformers import (
13
+ AutoModelForSequenceClassification,
14
+ AutoTokenizer,
15
+ DataCollatorWithPadding,
16
+ Trainer,
17
+ TrainingArguments,
18
+ set_seed,
19
+ )
20
+
21
+ from paper_classifier import BASE_MODEL_NAME, DEFAULT_MODEL_DIR, MAX_LENGTH, format_input_text
22
+
23
# Raw arXiv metadata dump (JSON list of paper dicts).
DATA_PATH = Path("arxivData.json")
# Where the finetuned model, tokenizer and summary are written.
OUTPUT_DIR = Path(DEFAULT_MODEL_DIR)
# Hugging Face download cache (writable in containerized environments).
HF_CACHE_DIR = Path("/tmp/huggingface")

# Field names inside each record of the raw JSON dump.
TITLE_FIELD = "title"
ABSTRACT_FIELD = "summary"
TAG_FIELD = "tag"

# Training hyperparameters.
VALIDATION_SIZE = 0.1
NUM_TRAIN_EPOCHS = 4
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
PER_DEVICE_TRAIN_BATCH_SIZE = 16
PER_DEVICE_EVAL_BATCH_SIZE = 32
LOGGING_STEPS = 50
SEED = 42

# Maps the arXiv category prefix (text before the dot in a tag term,
# e.g. "cs" in "cs.LG") to a top-level subject label.
PREFIX_TO_LABEL = {
    "adap-org": "Quantitative Biology",
    "astro-ph": "Physics",
    "cmp-lg": "Computer Science",
    "cond-mat": "Physics",
    "cs": "Computer Science",
    "econ": "Economics",
    "eess": "Electrical Engineering and Systems Science",
    "gr-qc": "Physics",
    "hep-ex": "Physics",
    "hep-lat": "Physics",
    "hep-ph": "Physics",
    "hep-th": "Physics",
    "math": "Mathematics",
    "nlin": "Physics",
    "nucl-th": "Physics",
    "physics": "Physics",
    "q-bio": "Quantitative Biology",
    "q-fin": "Quantitative Finance",
    "quant-ph": "Physics",
    "stat": "Statistics",
}
62
+
63
+
64
def normalize_text(value):
    """Collapse runs of whitespace in *value* to single spaces.

    Falsy inputs (None, "", 0, ...) yield the empty string; anything else
    is stringified first.
    """
    text = str(value or "")
    return " ".join(text.split())
66
+
67
+
68
def parse_top_level_label(raw_tag):
    """Resolve an arXiv tag payload to one of the top-level subject labels.

    *raw_tag* is expected to be the string repr of a list of dicts with a
    "term" key (e.g. "[{'term': 'cs.LG', ...}]"). The first term whose
    prefix (text before the dot) is found in PREFIX_TO_LABEL determines
    the label. Returns None when the payload is missing, unparsable, not
    a list, or no term matches.
    """
    if not raw_tag:
        return None

    try:
        tags = ast.literal_eval(str(raw_tag))
    except (SyntaxError, ValueError):
        return None

    if not isinstance(tags, list):
        return None

    for entry in tags:
        if not isinstance(entry, dict):
            continue
        term = entry.get("term")
        if not term:
            continue
        label = PREFIX_TO_LABEL.get(str(term).split(".")[0])
        if label:
            return label

    return None
92
+
93
+
94
def build_records():
    """Load the raw arXiv dump and return training records.

    Returns a list of {"text": <formatted title+abstract>, "label": <top-level
    category>} dicts. Records whose tag does not resolve to a known category,
    or whose formatted text is empty, are skipped and tallied.

    Bug fix: previously records with ``label=None`` (unmapped tags) were kept,
    which made ``sorted({record["label"] ...})`` in ``main`` raise TypeError
    (None is not orderable against str) and would KeyError in ``label2id``.
    The ``skipped`` counter was also created but never used; it now reports
    why records were dropped.
    """
    with DATA_PATH.open("r", encoding="utf-8") as file:
        raw_records = json.load(file)

    prepared_records: list[dict[str, str]] = []
    skipped = Counter()

    for item in raw_records:
        title = normalize_text(item.get(TITLE_FIELD))
        abstract = normalize_text(item.get(ABSTRACT_FIELD))
        label = parse_top_level_label(item.get(TAG_FIELD))
        text = format_input_text(title, abstract)

        if label is None:
            skipped["unknown_label"] += 1
            continue
        if not text:
            skipped["empty_text"] += 1
            continue

        prepared_records.append(
            {
                "text": text,
                "label": label,
            }
        )

    print(f"Loaded {len(prepared_records)}")
    if skipped:
        print("Skipped:", dict(skipped))

    label_distribution = Counter(record["label"] for record in prepared_records)
    print("Label distribution:", dict(label_distribution))
    return prepared_records
118
+
119
+
120
def build_splits(records):
    """Shuffle *records* into a seeded train/validation DatasetDict split."""
    full_dataset = Dataset.from_list(records)
    parts = full_dataset.train_test_split(test_size=VALIDATION_SIZE, seed=SEED)
    return DatasetDict(train=parts["train"], validation=parts["test"])
124
+
125
+
126
def preprocess(batch, *, tokenizer, label2id):
    """Tokenize a batch of texts (truncated to MAX_LENGTH) and attach integer labels."""
    encoded = tokenizer(batch["text"], truncation=True, max_length=MAX_LENGTH)
    encoded["labels"] = [label2id[name] for name in batch["label"]]
    return encoded
130
+
131
+
132
def compute_metrics(eval_prediction):
    """Compute accuracy and macro-F1 from a (logits, labels) eval tuple."""
    logits, labels = eval_prediction
    predicted = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, predicted),
        "macro_f1": f1_score(labels, predicted, average="macro"),
    }
139
+
140
+
141
def main() -> None:
    """End-to-end finetuning: load data, train DistilBERT, save artifacts + summary."""
    if not DATA_PATH.exists():
        raise FileNotFoundError(f"Dataset file not found: {DATA_PATH}")

    set_seed(SEED)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    HF_CACHE_DIR.mkdir(parents=True, exist_ok=True)

    records = build_records()
    raw_splits = build_splits(records)

    # Deterministic label vocabulary, shared with the model config.
    label_names = sorted({record["label"] for record in records})
    label2id = {name: idx for idx, name in enumerate(label_names)}
    id2label = {idx: name for name, idx in label2id.items()}

    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL_NAME,
        cache_dir=HF_CACHE_DIR.as_posix(),
    )

    # Tokenize both splits; drop the raw text/label columns afterwards.
    tokenized_splits = raw_splits.map(
        partial(preprocess, tokenizer=tokenizer, label2id=label2id),
        batched=True,
        remove_columns=raw_splits["train"].column_names,
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL_NAME,
        cache_dir=HF_CACHE_DIR.as_posix(),
        num_labels=len(label_names),
        id2label=id2label,
        label2id=label2id,
    )
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR.as_posix(),
        do_train=True,
        do_eval=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=LOGGING_STEPS,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
        # Reload the checkpoint with the best macro-F1 at the end of training.
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        save_total_limit=2,
        report_to=[],
        seed=SEED,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_splits["train"],
        eval_dataset=tokenized_splits["validation"],
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    eval_metrics = trainer.evaluate()
    trainer.save_model(OUTPUT_DIR.as_posix())
    tokenizer.save_pretrained(OUTPUT_DIR.as_posix())

    # Persist a machine-readable run summary next to the model weights.
    summary_path = OUTPUT_DIR / "training_summary.json"
    summary = {
        "base_model": BASE_MODEL_NAME,
        "data_path": DATA_PATH.as_posix(),
        "output_dir": OUTPUT_DIR.as_posix(),
        "title_field": TITLE_FIELD,
        "abstract_field": ABSTRACT_FIELD,
        "tag_field": TAG_FIELD,
        "labels": label_names,
        "metrics": eval_metrics,
    }
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")

    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()