# SHAD_Homework / src/train_distilbert.py
# Hugging Face upload metadata: AndreyForty — "Upload 2 files" (commit af53d97, verified)
from __future__ import annotations
import ast
import json
from collections import Counter
from functools import partial
from pathlib import Path
import numpy as np
from datasets import Dataset, DatasetDict
from sklearn.metrics import accuracy_score, f1_score
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
Trainer,
TrainingArguments,
set_seed,
)
from paper_classifier import BASE_MODEL_NAME, DEFAULT_MODEL_DIR, MAX_LENGTH, format_input_text
DATA_PATH = Path("arxivData.json")
OUTPUT_DIR = Path(DEFAULT_MODEL_DIR)
HF_CACHE_DIR = Path("/tmp/huggingface")
TITLE_FIELD = "title"
ABSTRACT_FIELD = "summary"
TAG_FIELD = "tag"
VALIDATION_SIZE = 0.1
NUM_TRAIN_EPOCHS = 4
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
PER_DEVICE_TRAIN_BATCH_SIZE = 16
PER_DEVICE_EVAL_BATCH_SIZE = 32
LOGGING_STEPS = 50
SEED = 42
PREFIX_TO_LABEL = {
"adap-org": "Quantitative Biology",
"astro-ph": "Physics",
"cmp-lg": "Computer Science",
"cond-mat": "Physics",
"cs": "Computer Science",
"econ": "Economics",
"eess": "Electrical Engineering and Systems Science",
"gr-qc": "Physics",
"hep-ex": "Physics",
"hep-lat": "Physics",
"hep-ph": "Physics",
"hep-th": "Physics",
"math": "Mathematics",
"nlin": "Physics",
"nucl-th": "Physics",
"physics": "Physics",
"q-bio": "Quantitative Biology",
"q-fin": "Quantitative Finance",
"quant-ph": "Physics",
"stat": "Statistics",
}
def normalize_text(value):
return " ".join(str(value or "").split())
def parse_top_level_label(raw_tag):
if not raw_tag:
return None
try:
parsed_tags = ast.literal_eval(str(raw_tag))
except (SyntaxError, ValueError):
return None
if not isinstance(parsed_tags, list):
return None
for tag in parsed_tags:
if not isinstance(tag, dict):
continue
term = tag.get("term")
if not term:
continue
prefix = str(term).split(".")[0]
label = PREFIX_TO_LABEL.get(prefix)
if label:
return label
return None
def build_records():
with DATA_PATH.open("r", encoding="utf-8") as file:
raw_records = json.load(file)
prepared_records: list[dict[str, str]] = []
skipped = Counter()
for item in raw_records:
title = normalize_text(item.get(TITLE_FIELD))
abstract = normalize_text(item.get(ABSTRACT_FIELD))
label = parse_top_level_label(item.get(TAG_FIELD))
text = format_input_text(title, abstract)
prepared_records.append(
{
"text": text,
"label": label,
}
)
print(f"Loaded {len(prepared_records)}")
label_distribution = Counter(record["label"] for record in prepared_records)
print("Label distribution:", dict(label_distribution))
return prepared_records
def build_splits(records):
dataset = Dataset.from_list(records)
split = dataset.train_test_split(test_size=VALIDATION_SIZE, seed=SEED)
return DatasetDict(train=split["train"], validation=split["test"])
def preprocess(batch, *, tokenizer, label2id):
tokenized = tokenizer(batch["text"], truncation=True, max_length=MAX_LENGTH)
tokenized["labels"] = [label2id[label] for label in batch["label"]]
return tokenized
def compute_metrics(eval_prediction):
logits, labels = eval_prediction
predictions = np.argmax(logits, axis=-1)
return {
"accuracy": accuracy_score(labels, predictions),
"macro_f1": f1_score(labels, predictions, average="macro"),
}
def main() -> None:
if not DATA_PATH.exists():
raise FileNotFoundError(f"Dataset file not found: {DATA_PATH}")
set_seed(SEED)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
HF_CACHE_DIR.mkdir(parents=True, exist_ok=True)
records = build_records()
raw_splits = build_splits(records)
label_names = sorted({record["label"] for record in records})
label2id = {label: index for index, label in enumerate(label_names)}
id2label = {index: label for label, index in label2id.items()}
tokenizer = AutoTokenizer.from_pretrained(
BASE_MODEL_NAME,
cache_dir=HF_CACHE_DIR.as_posix(),
)
tokenized_splits = raw_splits.map(
partial(preprocess, tokenizer=tokenizer, label2id=label2id),
batched=True,
remove_columns=raw_splits["train"].column_names,
)
model = AutoModelForSequenceClassification.from_pretrained(
BASE_MODEL_NAME,
cache_dir=HF_CACHE_DIR.as_posix(),
num_labels=len(label_names),
id2label=id2label,
label2id=label2id,
)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
training_args = TrainingArguments(
output_dir=OUTPUT_DIR.as_posix(),
do_train=True,
do_eval=True,
eval_strategy="epoch",
save_strategy="epoch",
logging_strategy="steps",
logging_steps=LOGGING_STEPS,
learning_rate=LEARNING_RATE,
weight_decay=WEIGHT_DECAY,
num_train_epochs=NUM_TRAIN_EPOCHS,
per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
load_best_model_at_end=True,
metric_for_best_model="macro_f1",
greater_is_better=True,
save_total_limit=2,
report_to=[],
seed=SEED,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_splits["train"],
eval_dataset=tokenized_splits["validation"],
processing_class=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
trainer.train()
metrics = trainer.evaluate()
trainer.save_model(OUTPUT_DIR.as_posix())
tokenizer.save_pretrained(OUTPUT_DIR.as_posix())
summary_path = OUTPUT_DIR / "training_summary.json"
summary = {
"base_model": BASE_MODEL_NAME,
"data_path": DATA_PATH.as_posix(),
"output_dir": OUTPUT_DIR.as_posix(),
"title_field": TITLE_FIELD,
"abstract_field": ABSTRACT_FIELD,
"tag_field": TAG_FIELD,
"labels": label_names,
"metrics": metrics,
}
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
print(json.dumps(summary, indent=2))
if __name__ == "__main__":
main()