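"""Fine-tune a Chinese seq2seq summarization model on JSONL data.

The local data_utils.load_jsonl helper is expected to return records exposing
`.article` and `.summary` attributes (see JsonlSeq2SeqDataset below), e.g.
parsed from lines like:

    {"article": "...", "summary": "..."}

Example invocation (assuming this file is saved as train.py; paths are
placeholders):

    python train.py \
        --train-path data/train.jsonl \
        --valid-path data/valid.jsonl \
        --output-dir outputs/bart-base-chinese
"""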
from __future__ import annotations

import argparse
from pathlib import Path

try:
    import torch
    from torch.utils.data import Dataset
    from transformers import (
        AutoModelForSeq2SeqLM,
        AutoTokenizer,
        DataCollatorForSeq2Seq,
        Seq2SeqTrainer,
        Seq2SeqTrainingArguments,
    )
except Exception as exc:  # pragma: no cover
    raise SystemExit(
        "Training requires torch, transformers and accelerate. Install dependencies first."
    ) from exc

from data_utils import load_jsonl

class JsonlSeq2SeqDataset(Dataset):
    """Tokenizes (article, summary) pairs from a JSONL file on the fly."""

    def __init__(self, path, tokenizer, max_source_length: int, max_target_length: int):
        self.examples = load_jsonl(path)
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        # Truncate the article to the encoder budget; no padding here, since
        # the collator pads per batch.
        model_inputs = self.tokenizer(
            ex.article,
            max_length=self.max_source_length,
            truncation=True,
        )
        labels = self.tokenizer(
            text_target=ex.summary,
            max_length=self.max_target_length,
            truncation=True,
        )
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
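
# Items above are returned unpadded: DataCollatorForSeq2Seq (used in main)
# pads input_ids/attention_mask per batch and pads labels with -100 by
# default, which the loss function ignores.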

def parse_args():
    parser = argparse.ArgumentParser(description="Fine-tune a Chinese seq2seq summarization model")
    parser.add_argument("--train-path", required=True)
    parser.add_argument("--valid-path", default=None)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--model-name", default="fnlp/bart-base-chinese")
    parser.add_argument("--max-source-length", type=int, default=512)
    parser.add_argument("--max-target-length", type=int, default=128)
    parser.add_argument("--num-train-epochs", type=float, default=3.0)
    parser.add_argument("--train-batch-size", type=int, default=2)
    parser.add_argument("--eval-batch-size", type=int, default=2)
    parser.add_argument("--learning-rate", type=float, default=3e-5)
    parser.add_argument("--logging-steps", type=int, default=25)
    parser.add_argument("--save-steps", type=int, default=200)
    return parser.parse_args()

def main():
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)

    train_dataset = JsonlSeq2SeqDataset(
        args.train_path,
        tokenizer,
        max_source_length=args.max_source_length,
        max_target_length=args.max_target_length,
    )
    # The validation set is optional; without it, evaluation is disabled below.
    eval_dataset = (
        JsonlSeq2SeqDataset(
            args.valid_path,
            tokenizer,
            max_source_length=args.max_source_length,
            max_target_length=args.max_target_length,
        )
        if args.valid_path
        else None
    )

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
    training_args = Seq2SeqTrainingArguments(
        output_dir=str(output_dir),
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        predict_with_generate=True,
        num_train_epochs=args.num_train_epochs,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=2,
        # Evaluate on the same cadence as checkpointing when a validation set is given.
        evaluation_strategy="steps" if eval_dataset else "no",
        eval_steps=args.save_steps if eval_dataset else None,
        # Use mixed precision only when a CUDA device is available.
        fp16=torch.cuda.is_available(),
        report_to=[],
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.train()

    # Persist both the fine-tuned weights and the tokenizer for later inference.
    trainer.save_model(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))
    print(f"saved to {output_dir}")


if __name__ == "__main__":
    main()