"""Fine-tune a sequence-classification model on arXiv titles and abstracts.

The script reads ``arxivData.json``, maps each paper's first recognised arXiv
tag prefix to a top-level subject label, fine-tunes ``BASE_MODEL_NAME`` on the
resulting (text, label) pairs, and writes the model, tokenizer and a JSON
training summary to ``OUTPUT_DIR``.
"""

from __future__ import annotations

import ast
import json
from collections import Counter
from functools import partial
from pathlib import Path

import numpy as np
from datasets import Dataset, DatasetDict
from sklearn.metrics import accuracy_score, f1_score
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    set_seed,
)

from paper_classifier import BASE_MODEL_NAME, DEFAULT_MODEL_DIR, MAX_LENGTH, format_input_text

# --- Paths -----------------------------------------------------------------
DATA_PATH = Path("arxivData.json")
OUTPUT_DIR = Path(DEFAULT_MODEL_DIR)
HF_CACHE_DIR = Path("/tmp/huggingface")

# --- Keys read from each raw JSON record -------------------------------------
TITLE_FIELD = "title"
ABSTRACT_FIELD = "summary"
TAG_FIELD = "tag"

# --- Training hyper-parameters ---------------------------------------------
VALIDATION_SIZE = 0.1
NUM_TRAIN_EPOCHS = 4
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
PER_DEVICE_TRAIN_BATCH_SIZE = 16
PER_DEVICE_EVAL_BATCH_SIZE = 32
LOGGING_STEPS = 50
SEED = 42

# Maps the part of an arXiv term before the first "." to a top-level label.
# NOTE(review): "adap-org" is mapped to "Quantitative Biology" here; confirm
# this is intentional rather than "Physics" like "nlin".
PREFIX_TO_LABEL = {
    "adap-org": "Quantitative Biology",
    "astro-ph": "Physics",
    "cmp-lg": "Computer Science",
    "cond-mat": "Physics",
    "cs": "Computer Science",
    "econ": "Economics",
    "eess": "Electrical Engineering and Systems Science",
    "gr-qc": "Physics",
    "hep-ex": "Physics",
    "hep-lat": "Physics",
    "hep-ph": "Physics",
    "hep-th": "Physics",
    "math": "Mathematics",
    "nlin": "Physics",
    "nucl-th": "Physics",
    "physics": "Physics",
    "q-bio": "Quantitative Biology",
    "q-fin": "Quantitative Finance",
    "quant-ph": "Physics",
    "stat": "Statistics",
}


def normalize_text(value: object) -> str:
    """Return *value* as a string with whitespace runs collapsed to one space.

    Falsy values (``None``, ``""``, ``0``) become the empty string.
    """
    return " ".join(str(value or "").split())


def parse_top_level_label(raw_tag: object) -> str | None:
    """Return the top-level label for the first recognised tag in *raw_tag*.

    *raw_tag* is stringified and parsed as a Python literal; it must evaluate
    to a list whose dict elements carry a ``"term"`` entry. The text before
    the first ``"."`` of each term is looked up in ``PREFIX_TO_LABEL`` and the
    first hit wins.

    Returns:
        The mapped label, or ``None`` when the value is empty, cannot be
        parsed, is not a list, or contains no recognised prefix.
    """
    if not raw_tag:
        return None

    try:
        parsed_tags = ast.literal_eval(str(raw_tag))
    except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
        # ast.literal_eval documents all five of these for malformed input;
        # one bad tag string must not abort the whole data preparation.
        return None

    if not isinstance(parsed_tags, list):
        return None

    for tag in parsed_tags:
        if not isinstance(tag, dict):
            continue
        term = tag.get("term")
        if not term:
            continue
        prefix = str(term).split(".")[0]
        label = PREFIX_TO_LABEL.get(prefix)
        if label:
            return label

    return None


def build_records() -> list[dict[str, str]]:
    """Load ``DATA_PATH`` and return the usable ``{"text", "label"}`` records.

    Records are skipped (and counted) when they have neither a title nor an
    abstract, or when no tag maps to a known label. Letting a ``None`` label
    through would break label sorting and the label-to-id lookup downstream.

    Returns:
        One dict per kept paper, with ``"text"`` produced by
        ``format_input_text(title, abstract)`` and ``"label"`` a value of
        ``PREFIX_TO_LABEL``.
    """
    with DATA_PATH.open("r", encoding="utf-8") as file:
        raw_records = json.load(file)

    prepared_records: list[dict[str, str]] = []
    skipped: Counter[str] = Counter()

    for item in raw_records:
        title = normalize_text(item.get(TITLE_FIELD))
        abstract = normalize_text(item.get(ABSTRACT_FIELD))
        if not title and not abstract:
            skipped["missing_text"] += 1
            continue

        label = parse_top_level_label(item.get(TAG_FIELD))
        if label is None:
            skipped["missing_label"] += 1
            continue

        prepared_records.append(
            {
                "text": format_input_text(title, abstract),
                "label": label,
            }
        )

    print(f"Loaded {len(prepared_records)}")
    if skipped:
        print("Skipped records:", dict(skipped))

    label_distribution = Counter(record["label"] for record in prepared_records)
    print("Label distribution:", dict(label_distribution))

    return prepared_records


def build_splits(records: list[dict[str, str]]) -> DatasetDict:
    """Split *records* into seeded ``train`` and ``validation`` datasets.

    ``VALIDATION_SIZE`` is passed as the test fraction and ``SEED`` as the
    shuffle seed.
    """
    dataset = Dataset.from_list(records)
    split = dataset.train_test_split(test_size=VALIDATION_SIZE, seed=SEED)
    return DatasetDict(train=split["train"], validation=split["test"])


def preprocess(batch, *, tokenizer, label2id):
    """Tokenize a batch of texts and attach integer ``labels``.

    Args:
        batch: Mapping with ``"text"`` and ``"label"`` columns (this function
            is used with ``map(..., batched=True)``).
        tokenizer: Callable tokenizer; invoked with truncation at
            ``MAX_LENGTH``.
        label2id: Mapping from label name to class index.

    Returns:
        The tokenizer output with an added ``"labels"`` entry.
    """
    tokenized = tokenizer(batch["text"], truncation=True, max_length=MAX_LENGTH)
    tokenized["labels"] = [label2id[label] for label in batch["label"]]
    return tokenized


def compute_metrics(eval_prediction) -> dict[str, float]:
    """Return accuracy and macro-F1 for a ``(logits, labels)`` pair.

    Predictions are the arg-max over the last axis of the logits.
    """
    logits, labels = eval_prediction
    predictions = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, predictions),
        "macro_f1": f1_score(labels, predictions, average="macro"),
    }


def main() -> None:
    """Prepare the data, fine-tune the classifier and save all artefacts.

    Raises:
        FileNotFoundError: If ``DATA_PATH`` does not exist.
        ValueError: If no record survives preparation.
    """
    if not DATA_PATH.exists():
        raise FileNotFoundError(f"Dataset file not found: {DATA_PATH}")

    set_seed(SEED)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    HF_CACHE_DIR.mkdir(parents=True, exist_ok=True)

    records = build_records()
    if not records:
        # Fail with a clear message instead of an obscure error while splitting.
        raise ValueError(f"No usable labelled records found in {DATA_PATH}")

    raw_splits = build_splits(records)

    # Sorting makes the label -> id assignment deterministic across runs.
    label_names = sorted({record["label"] for record in records})
    label2id = {label: index for index, label in enumerate(label_names)}
    id2label = {index: label for label, index in label2id.items()}

    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL_NAME,
        cache_dir=HF_CACHE_DIR.as_posix(),
    )
    tokenized_splits = raw_splits.map(
        partial(preprocess, tokenizer=tokenizer, label2id=label2id),
        batched=True,
        # Drop the raw text/label columns; only tokenizer outputs and
        # "labels" remain.
        remove_columns=raw_splits["train"].column_names,
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL_NAME,
        cache_dir=HF_CACHE_DIR.as_posix(),
        num_labels=len(label_names),
        id2label=id2label,
        label2id=label2id,
    )
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR.as_posix(),
        do_train=True,
        do_eval=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=LOGGING_STEPS,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
        # Checkpoint selection is driven by validation macro-F1.
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        save_total_limit=2,
        report_to=[],
        seed=SEED,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_splits["train"],
        eval_dataset=tokenized_splits["validation"],
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    metrics = trainer.evaluate()

    trainer.save_model(OUTPUT_DIR.as_posix())
    tokenizer.save_pretrained(OUTPUT_DIR.as_posix())

    summary_path = OUTPUT_DIR / "training_summary.json"
    summary = {
        "base_model": BASE_MODEL_NAME,
        "data_path": DATA_PATH.as_posix(),
        "output_dir": OUTPUT_DIR.as_posix(),
        "title_field": TITLE_FIELD,
        "abstract_field": ABSTRACT_FIELD,
        "tag_field": TAG_FIELD,
        "labels": label_names,
        "metrics": metrics,
    }
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()