import torch
from datasets import load_dataset
from transformers import (
    Trainer,
    T5Config,
    T5TokenizerFast,
    TrainingArguments,
    BatchEncoding,
    DataCollatorForSeq2Seq,
    T5ForConditionalGeneration
)
|
|
|
|
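# Paths for the training corpus, custom tokeniser, and checkpoints.
# base_model is kept only as an architecture reference; the model below is built from scratch.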
base_model = "t5-small"
data_path = "src/data/clean_corpus.jsonl"
tokeniser_path = "src/tokeniser/"
output_dir = "checkpoints/"
|
|
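# Load the custom tokeniser and pull out the values the model config needs.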
tokeniser = T5TokenizerFast.from_pretrained(tokeniser_path)
vocab_size = len(tokeniser)   # includes any added special tokens, unlike tokeniser.vocab_size
pad_token_id = tokeniser.pad_token_id
|
|
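# From-scratch T5 config with t5-small-scale dimensions.
# T5 conventionally reuses the pad token as the decoder start token.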
config = T5Config(
    vocab_size = vocab_size,
    d_model = 512,
    d_ff = 2048,
    num_layers = 6,
    num_heads = 8,
    pad_token_id = pad_token_id,
    decoder_start_token_id = pad_token_id
)
|
|
model = T5ForConditionalGeneration(config)
|
|
|
|
def tokenise_function(example: dict) -> BatchEncoding:
    """
    Tokenise a batch of Cyrillic source / Latin target pairs for seq2seq training.
    """
    inputs = [f"Cyrillic2Latin: {item['src']}" for item in example["transliteration"]]
    targets = [item["tgt"] for item in example["transliteration"]]

    # Truncate only; DataCollatorForSeq2Seq pads each batch dynamically,
    # which also keeps pad tokens out of the loss (labels are padded with -100).
    model_inputs = tokeniser(inputs, max_length = 128, truncation = True)
    labels = tokeniser(targets, max_length = 128, truncation = True)["input_ids"]

    model_inputs["labels"] = labels

    return model_inputs
|
|
|
|
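# Load the JSONL corpus; each record holds a "transliteration" field with "src" (Cyrillic) and "tgt" (Latin) strings.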
dataset = load_dataset("json", data_files = data_path, split = "train")
|
|
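# Hold out 25% of the corpus for evaluation (pass seed= here for a reproducible split).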
dataset_split = dataset.train_test_split(test_size = 0.25)
train_dataset = dataset_split["train"]
val_dataset = dataset_split["test"]
|
|
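# Tokenise both splits, dropping the raw column once it has been encoded.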
tokenised_train = train_dataset.map(tokenise_function, batched = True, remove_columns = ["transliteration"])
tokenised_eval = val_dataset.map(tokenise_function, batched = True, remove_columns = ["transliteration"])
|
|
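# Dynamic per-batch padding; labels are padded with -100 so the loss ignores them.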
data_collator = DataCollatorForSeq2Seq(tokenizer = tokeniser, model = model)
|
|
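# Effective batch size is 32 x 2 = 64 per device via gradient accumulation; fp16 only on GPU.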
training_args = TrainingArguments(
    output_dir = output_dir,
    overwrite_output_dir = True,
    num_train_epochs = 2,
    per_device_train_batch_size = 32,
    gradient_accumulation_steps = 2,
    save_strategy = "steps",
    save_steps = 500,
    save_total_limit = 3,
    eval_strategy = "epoch",
    logging_dir = "logs",
    fp16 = torch.cuda.is_available()
)
|
|
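# In recent transformers releases, processing_class supersedes the deprecated tokenizer argument.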
trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = tokenised_train,
    eval_dataset = tokenised_eval,
    data_collator = data_collator,
    processing_class = tokeniser
)
|
|
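# Train, then save the final model and tokeniser alongside the checkpoints.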
trainer.train()
model.save_pretrained(output_dir)
tokeniser.save_pretrained(output_dir)
|
|
print("DalaT5 training complete.")
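
# Optional smoke test: transliterate one input with the freshly trained model.
# The sample string below is an illustrative assumption (Kazakh Cyrillic), not taken from the corpus.
sample = tokeniser("Cyrillic2Latin: Қазақстан", return_tensors = "pt").to(model.device)
generated = model.generate(**sample, max_new_tokens = 64)
print(tokeniser.decode(generated[0], skip_special_tokens = True))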