"""Fine-tune a Chinese seq2seq summarization model on JSONL article/summary pairs."""

from __future__ import annotations

import argparse
from pathlib import Path

try:
    import torch
    from torch.utils.data import Dataset
    from transformers import (
        AutoModelForSeq2SeqLM,
        AutoTokenizer,
        DataCollatorForSeq2Seq,
        Seq2SeqTrainer,
        Seq2SeqTrainingArguments,
    )
except Exception as exc:  # pragma: no cover
    raise SystemExit(
        "Training requires torch, transformers and accelerate. Install dependencies first."
    ) from exc

from data_utils import load_jsonl


class JsonlSeq2SeqDataset(Dataset):
    """Tokenizes article/summary pairs loaded from a JSONL file on the fly."""

    def __init__(self, path, tokenizer, max_source_length: int, max_target_length: int):
        self.examples = load_jsonl(path)
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        # Encode the source article, truncating to the encoder's length budget.
        model_inputs = self.tokenizer(
            ex.article,
            max_length=self.max_source_length,
            truncation=True,
        )
        # text_target= switches the tokenizer into target-side mode for the labels.
        labels = self.tokenizer(
            text_target=ex.summary,
            max_length=self.max_target_length,
            truncation=True,
        )
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs


def parse_args():
    parser = argparse.ArgumentParser(description="Fine-tune a Chinese seq2seq summarization model")
    parser.add_argument("--train-path", required=True)
    parser.add_argument("--valid-path", default=None)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--model-name", default="fnlp/bart-base-chinese")
    parser.add_argument("--max-source-length", type=int, default=512)
    parser.add_argument("--max-target-length", type=int, default=128)
    parser.add_argument("--num-train-epochs", type=float, default=3.0)
    parser.add_argument("--train-batch-size", type=int, default=2)
    parser.add_argument("--eval-batch-size", type=int, default=2)
    parser.add_argument("--learning-rate", type=float, default=3e-5)
    parser.add_argument("--logging-steps", type=int, default=25)
    parser.add_argument("--save-steps", type=int, default=200)
    return parser.parse_args()


def main():
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)

    train_dataset = JsonlSeq2SeqDataset(
        args.train_path,
        tokenizer,
        max_source_length=args.max_source_length,
        max_target_length=args.max_target_length,
    )
    # Evaluation is optional: only build an eval dataset when a validation file is given.
    eval_dataset = (
        JsonlSeq2SeqDataset(
            args.valid_path,
            tokenizer,
            max_source_length=args.max_source_length,
            max_target_length=args.max_target_length,
        )
        if args.valid_path
        else None
    )

    # Pads each batch dynamically; labels are padded with -100 so the loss ignores them.
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    training_args = Seq2SeqTrainingArguments(
        output_dir=str(output_dir),
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        predict_with_generate=True,
        num_train_epochs=args.num_train_epochs,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=2,
        evaluation_strategy="steps" if eval_dataset else "no",
        eval_steps=args.save_steps if eval_dataset else None,
        fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
        report_to=[],
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.train()
    trainer.save_model(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))
    print(f"saved to {output_dir}")


if __name__ == "__main__":
    main()
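
# `load_jsonl` comes from the local `data_utils` module, which is not part of
# this file. Based on how it is used above (records exposing `.article` and
# `.summary` attributes), a minimal sketch of a compatible implementation is
# kept here for reference. The dataclass name and the JSON keys are
# assumptions, not the project's actual code.
#
#     from dataclasses import dataclass
#     import json
#
#     @dataclass
#     class Example:
#         article: str
#         summary: str
#
#     def load_jsonl(path):
#         examples = []
#         with open(path, encoding="utf-8") as handle:
#             for line in handle:
#                 line = line.strip()
#                 if line:
#                     record = json.loads(line)
#                     examples.append(Example(record["article"], record["summary"]))
#         return examples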
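
# Example invocation (the script name and paths are hypothetical; adjust to
# your layout). Each JSONL line is assumed to hold one article/summary record.
#
#     python train.py \
#         --train-path data/train.jsonl \
#         --valid-path data/valid.jsonl \
#         --output-dir runs/bart-base-chinese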