Transformer / train.py
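"""Fine-tune a seq2seq model for Chinese text summarization.

Reads JSONL training (and optionally validation) data via data_utils.load_jsonl,
tokenizes article/summary pairs on the fly, and trains with the Hugging Face
Seq2SeqTrainer. The fine-tuned model and tokenizer are written to --output-dir.
"""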
from __future__ import annotations

import argparse
from pathlib import Path

try:
    import torch
    from torch.utils.data import Dataset
    from transformers import (
        AutoModelForSeq2SeqLM,
        AutoTokenizer,
        DataCollatorForSeq2Seq,
        Seq2SeqTrainer,
        Seq2SeqTrainingArguments,
    )
except Exception as exc:  # pragma: no cover
    raise SystemExit(
        "Training requires torch, transformers and accelerate. Install the dependencies first."
    ) from exc

from data_utils import load_jsonl

class JsonlSeq2SeqDataset(Dataset):
    """Tokenizes article/summary pairs from a JSONL file on the fly.

    Each record returned by load_jsonl is expected to expose ``article`` and
    ``summary`` attributes, which is how the fields are accessed below.
    """

    def __init__(self, path, tokenizer, max_source_length: int, max_target_length: int):
        self.examples = load_jsonl(path)
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        # Tokenize the source article; padding is deferred to the data collator.
        model_inputs = self.tokenizer(
            ex.article,
            max_length=self.max_source_length,
            truncation=True,
        )
        # Tokenize the reference summary as target labels.
        labels = self.tokenizer(
            text_target=ex.summary,
            max_length=self.max_target_length,
            truncation=True,
        )
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
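

# The defaults below target fnlp/bart-base-chinese, but any Hugging Face seq2seq
# checkpoint with a compatible tokenizer should also work via --model-name.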
def parse_args():
    parser = argparse.ArgumentParser(description="Fine-tune a Chinese seq2seq summarization model")
    parser.add_argument("--train-path", required=True)
    parser.add_argument("--valid-path", default=None)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--model-name", default="fnlp/bart-base-chinese")
    parser.add_argument("--max-source-length", type=int, default=512)
    parser.add_argument("--max-target-length", type=int, default=128)
    parser.add_argument("--num-train-epochs", type=float, default=3.0)
    parser.add_argument("--train-batch-size", type=int, default=2)
    parser.add_argument("--eval-batch-size", type=int, default=2)
    parser.add_argument("--learning-rate", type=float, default=3e-5)
    parser.add_argument("--logging-steps", type=int, default=25)
    parser.add_argument("--save-steps", type=int, default=200)
    return parser.parse_args()

def main():
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)

    train_dataset = JsonlSeq2SeqDataset(
        args.train_path,
        tokenizer,
        max_source_length=args.max_source_length,
        max_target_length=args.max_target_length,
    )
    eval_dataset = (
        JsonlSeq2SeqDataset(
            args.valid_path,
            tokenizer,
            max_source_length=args.max_source_length,
            max_target_length=args.max_target_length,
        )
        if args.valid_path
        else None
    )

    # Dynamic padding per batch; the collator also pads labels with -100 so that
    # padded positions are ignored by the loss.
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    training_args = Seq2SeqTrainingArguments(
        output_dir=str(output_dir),
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        predict_with_generate=True,
        num_train_epochs=args.num_train_epochs,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=2,
        # Note: recent transformers releases rename evaluation_strategy to eval_strategy.
        evaluation_strategy="steps" if eval_dataset else "no",
        eval_steps=args.save_steps if eval_dataset else None,
        fp16=torch.cuda.is_available(),
        report_to=[],
    )
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.train()

    # Save the final model and tokenizer so the output directory is directly loadable.
    trainer.save_model(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))
    print(f"saved to {output_dir}")


if __name__ == "__main__":
    main()
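

# Example invocation (paths and output name are illustrative, not part of the repo):
#   python train.py \
#       --train-path data/train.jsonl \
#       --valid-path data/valid.jsonl \
#       --output-dir outputs/bart-base-chinese-summarizer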