"""Fine-tune rubai-corrector-base on JSONL correction pairs."""

from __future__ import annotations

import argparse
import json
import random
from pathlib import Path
from typing import Any

import torch
from torch.utils.data import Dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    get_linear_schedule_with_warmup,
    set_seed,
)
from transformers.optimization import Adafactor


# Task prefix prepended to every source sequence before tokenization.
INPUT_PREFIX = "correct: "


def resolve_jsonl_path(path: Path) -> Path:
    """Return the JSONL file itself, or data.jsonl inside a directory argument."""
    if path.is_dir():
        candidate = path / "data.jsonl"
        if candidate.exists():
            return candidate
        raise FileNotFoundError(f"Directory {path} does not contain data.jsonl")
    return path


def load_records(path: Path) -> list[dict[str, Any]]:
    """Read correction pairs, validating string 'input' and 'output' fields."""
    data_path = resolve_jsonl_path(path)
    records: list[dict[str, Any]] = []
    with data_path.open("r", encoding="utf-8") as handle:
        for line_num, line in enumerate(handle, start=1):
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            if not isinstance(record.get("input"), str) or not isinstance(record.get("output"), str):
                raise ValueError(
                    f"{data_path}:{line_num} must contain string fields 'input' and 'output'"
                )
            records.append(record)
    if not records:
        raise ValueError(f"No records loaded from {data_path}")
    return records
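
# Each JSONL line is one correction pair; an illustrative record:
#   {"input": "teh cat sat on teh mat", "output": "the cat sat on the mat"}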


def split_records(
    records: list[dict[str, Any]],
    validation_split: float,
    seed: int,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Deterministically shuffle and split records into (train, eval) lists."""
    if validation_split <= 0:
        return records, []
    if not 0 < validation_split < 1:
        raise ValueError("--validation-split must be between 0 and 1")
    items = records[:]
    random.Random(seed).shuffle(items)
    eval_size = max(1, int(len(items) * validation_split))
    return items[eval_size:], items[:eval_size]


class CorrectionDataset(Dataset):
    """Tokenize (input, output) correction pairs on demand."""

    def __init__(
        self,
        records: list[dict[str, Any]],
        tokenizer,
        max_source_length: int,
        max_target_length: int,
    ):
        self.records = records
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, index: int) -> dict[str, Any]:
        record = self.records[index]
        model_inputs = self.tokenizer(
            INPUT_PREFIX + record["input"],
            truncation=True,
            max_length=self.max_source_length,
        )
        labels = self.tokenizer(
            record["output"],
            truncation=True,
            max_length=self.max_target_length,
        )
        # Padding is deferred to DataCollatorForSeq2Seq at batch time.
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
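
# Each item is a dict of unpadded token-id lists (values illustrative):
#   {"input_ids": [...], "attention_mask": [...], "labels": [...]}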


class AdafactorSeq2SeqTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer that uses Adafactor with an external linear warmup schedule."""

    def create_optimizer(self):
        if self.optimizer is None:
            # relative_step/warmup_init are disabled so Adafactor uses the fixed
            # learning rate here and leaves scheduling to create_scheduler below.
            self.optimizer = Adafactor(
                self.model.parameters(),
                lr=self.args.learning_rate,
                scale_parameter=False,
                relative_step=False,
                warmup_init=False,
                weight_decay=self.args.weight_decay,
            )
        return self.optimizer

    def create_scheduler(self, num_training_steps: int, optimizer=None):
        if self.lr_scheduler is None:
            actual_optimizer = optimizer if optimizer is not None else self.optimizer
            self.lr_scheduler = get_linear_schedule_with_warmup(
                actual_optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
            )
        return self.lr_scheduler


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model-path", type=Path, default=Path(__file__).resolve().parent)
    parser.add_argument("--train-file", type=Path, required=True)
    parser.add_argument("--eval-file", type=Path, default=None)
    parser.add_argument("--validation-split", type=float, default=0.0)
    parser.add_argument("--output-dir", type=Path, required=True)
    parser.add_argument("--max-source-length", type=int, default=512)
    parser.add_argument("--max-target-length", type=int, default=512)
    parser.add_argument("--learning-rate", type=float, default=5e-5)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--warmup-ratio", type=float, default=0.1)
    parser.add_argument("--num-train-epochs", type=float, default=2.0)
    parser.add_argument("--per-device-train-batch-size", type=int, default=16)
    parser.add_argument("--per-device-eval-batch-size", type=int, default=16)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
    parser.add_argument("--save-steps", type=int, default=500)
    parser.add_argument("--eval-steps", type=int, default=500)
    parser.add_argument("--logging-steps", type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--fp16", action="store_true")
    # Checkpointing is on by default; pass --no-gradient-checkpointing to disable it.
    parser.add_argument("--gradient-checkpointing", action="store_true", default=True)
    parser.add_argument("--no-gradient-checkpointing", action="store_true")
    parser.add_argument("--resume-from-checkpoint", type=str, default=None)
    return parser.parse_args()


def main() -> int:
    args = parse_args()
    set_seed(args.seed)

    gradient_checkpointing = args.gradient_checkpointing and not args.no_gradient_checkpointing

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_path)
    if gradient_checkpointing:
        model.gradient_checkpointing_enable()
        # Gradient checkpointing is incompatible with cached decoder states,
        # so turn the cache off for training.
        model.config.use_cache = False

    train_records = load_records(args.train_file)
    if args.eval_file is not None:
        eval_records = load_records(args.eval_file)
    else:
        train_records, eval_records = split_records(train_records, args.validation_split, args.seed)

    train_dataset = CorrectionDataset(
        train_records,
        tokenizer,
        max_source_length=args.max_source_length,
        max_target_length=args.max_target_length,
    )
    eval_dataset = None
    if eval_records:
        eval_dataset = CorrectionDataset(
            eval_records,
            tokenizer,
            max_source_length=args.max_source_length,
            max_target_length=args.max_target_length,
        )

    # -100 label padding is ignored by the cross-entropy loss; padding to a
    # multiple of 8 helps tensor-core throughput on CUDA.
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100,
        pad_to_multiple_of=8 if torch.cuda.is_available() else None,
    )

    training_args = Seq2SeqTrainingArguments(
        output_dir=str(args.output_dir),
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        evaluation_strategy="steps" if eval_dataset is not None else "no",
        save_strategy="steps",
        save_total_limit=2,
        predict_with_generate=False,
        report_to=[],
        bf16=args.bf16,
        fp16=args.fp16 and not args.bf16,  # bf16 takes precedence if both are set
        gradient_checkpointing=gradient_checkpointing,
        dataloader_num_workers=2,
        remove_unused_columns=False,
        seed=args.seed,
    )

    trainer = AdafactorSeq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())