#!/usr/bin/env python3
"""Fine-tune rubai-corrector-base on JSONL correction pairs."""
from __future__ import annotations
import argparse
import json
import random
from pathlib import Path
from typing import Any
import torch
from torch.utils.data import Dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    get_linear_schedule_with_warmup,
    set_seed,
)
from transformers.optimization import Adafactor
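
# Task prefix prepended to every source sentence before tokenization, in the
# style commonly used for T5-like seq2seq models.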
INPUT_PREFIX = "correct: "


def resolve_jsonl_path(path: Path) -> Path:
    if path.is_dir():
        candidate = path / "data.jsonl"
        if candidate.exists():
            return candidate
        raise FileNotFoundError(f"Directory {path} does not contain data.jsonl")
    return path


def load_records(path: Path) -> list[dict[str, Any]]:
    data_path = resolve_jsonl_path(path)
    records: list[dict[str, Any]] = []
    with data_path.open("r", encoding="utf-8") as handle:
        for line_num, line in enumerate(handle, start=1):
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            if not isinstance(record.get("input"), str) or not isinstance(record.get("output"), str):
                raise ValueError(
                    f"{data_path}:{line_num} must contain string fields 'input' and 'output'"
                )
            records.append(record)
    if not records:
        raise ValueError(f"No records loaded from {data_path}")
    return records


def split_records(
    records: list[dict[str, Any]],
    validation_split: float,
    seed: int,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    if validation_split <= 0:
        return records, []
    if not 0 < validation_split < 1:
        raise ValueError("--validation-split must be between 0 and 1")
    items = records[:]
    random.Random(seed).shuffle(items)
    eval_size = max(1, int(len(items) * validation_split))
    return items[eval_size:], items[:eval_size]
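

# Tokenizes one (input, output) pair per item; padding is deferred to the
# data collator so batches can be padded dynamically.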
class CorrectionDataset(Dataset):
    def __init__(
        self,
        records: list[dict[str, Any]],
        tokenizer,
        max_source_length: int,
        max_target_length: int,
    ):
        self.records = records
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, index: int) -> dict[str, Any]:
        record = self.records[index]
        model_inputs = self.tokenizer(
            INPUT_PREFIX + record["input"],
            truncation=True,
            max_length=self.max_source_length,
        )
        labels = self.tokenizer(
            record["output"],
            truncation=True,
            max_length=self.max_target_length,
        )
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
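

# Seq2SeqTrainer variant that replaces the default AdamW optimizer with
# Adafactor (constant external learning rate, relative_step disabled) and a
# linear warmup/decay schedule.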
class AdafactorSeq2SeqTrainer(Seq2SeqTrainer):
    def create_optimizer(self):
        if self.optimizer is None:
            self.optimizer = Adafactor(
                self.model.parameters(),
                lr=self.args.learning_rate,
                scale_parameter=False,
                relative_step=False,
                warmup_init=False,
                weight_decay=self.args.weight_decay,
            )
        return self.optimizer

    def create_scheduler(self, num_training_steps: int, optimizer=None):
        if self.lr_scheduler is None:
            actual_optimizer = optimizer if optimizer is not None else self.optimizer
            self.lr_scheduler = get_linear_schedule_with_warmup(
                actual_optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
            )
        return self.lr_scheduler


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model-path", type=Path, default=Path(__file__).resolve().parent)
    parser.add_argument("--train-file", type=Path, required=True)
    parser.add_argument("--eval-file", type=Path, default=None)
    parser.add_argument("--validation-split", type=float, default=0.0)
    parser.add_argument("--output-dir", type=Path, required=True)
    parser.add_argument("--max-source-length", type=int, default=512)
    parser.add_argument("--max-target-length", type=int, default=512)
    parser.add_argument("--learning-rate", type=float, default=5e-5)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--warmup-ratio", type=float, default=0.1)
    parser.add_argument("--num-train-epochs", type=float, default=2.0)
    parser.add_argument("--per-device-train-batch-size", type=int, default=16)
    parser.add_argument("--per-device-eval-batch-size", type=int, default=16)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
    parser.add_argument("--save-steps", type=int, default=500)
    parser.add_argument("--eval-steps", type=int, default=500)
    parser.add_argument("--logging-steps", type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--gradient-checkpointing", action="store_true", default=True)
    parser.add_argument("--no-gradient-checkpointing", action="store_true")
    parser.add_argument("--resume-from-checkpoint", type=str, default=None)
    return parser.parse_args()


def main() -> int:
    args = parse_args()
    set_seed(args.seed)
    gradient_checkpointing = args.gradient_checkpointing and not args.no_gradient_checkpointing
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_path)
    if gradient_checkpointing:
        model.gradient_checkpointing_enable()
    train_records = load_records(args.train_file)
    if args.eval_file is not None:
        eval_records = load_records(args.eval_file)
    else:
        train_records, eval_records = split_records(train_records, args.validation_split, args.seed)
    train_dataset = CorrectionDataset(
        train_records,
        tokenizer,
        max_source_length=args.max_source_length,
        max_target_length=args.max_target_length,
    )
    eval_dataset = None
    if eval_records:
        eval_dataset = CorrectionDataset(
            eval_records,
            tokenizer,
            max_source_length=args.max_source_length,
            max_target_length=args.max_target_length,
        )
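    # Dynamic padding happens in the collator; label_pad_token_id=-100 keeps
    # padded label positions out of the loss, and pad_to_multiple_of=8 helps
    # tensor-core efficiency when a CUDA device is available.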
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100,
        pad_to_multiple_of=8 if torch.cuda.is_available() else None,
    )
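    # Note: with the default settings (per-device batch 16, accumulation 4),
    # each optimizer step covers an effective batch of 64 examples per device.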
    training_args = Seq2SeqTrainingArguments(
        output_dir=str(args.output_dir),
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        evaluation_strategy="steps" if eval_dataset is not None else "no",
        save_strategy="steps",
        save_total_limit=2,
        predict_with_generate=False,
        report_to=[],
        bf16=args.bf16,
        fp16=args.fp16 and not args.bf16,
        gradient_checkpointing=gradient_checkpointing,
        dataloader_num_workers=2,
        remove_unused_columns=False,
        seed=args.seed,
    )
    trainer = AdafactorSeq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())