"""Fine-tune Llama-3.2-3B as a causal LM on one perturbed BabyLM training set.

The perturbation, training-set size, batch size, epoch count and seed come
from the command line.  Trainer checkpoints are written under
``./checkpoints/Llama-3.2-3B/<run_id>/runs`` and the pretrained-model download
is cached under ``./checkpoints/Llama-3.2-3B/<run_id>/artifacts``.
"""

import argparse
import os
import sys

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# The parent directory must be on sys.path before the project helpers are
# imported (presumably that is where utils_llama lives -- confirm against the
# repository layout).
sys.path.append("..")

from utils_llama import (  # noqa: E402  (must follow the sys.path tweak)
    BABYLM_DATA_PATH,
    BABYLM_SPLITS,
    GENRES,
    MARKER_TOKEN_IDS,
    PERTURBATIONS,
    marker_pl_token,
    marker_rev_token,
    marker_sg_token,
    write_file,
)

# Root directory for everything this script writes.
CKPT_PATH = "./checkpoints"
# Hub identifier of the pretrained model that gets fine-tuned.
MODEL_NAME = "meta-llama/Llama-3.2-3B"
# Directory name used for this model under CKPT_PATH.
MODEL_SAVE_NAME = "Llama-3.2-3B"
# Every example is padded/truncated to exactly this many tokens.
MAX_SEQ_LENGTH = 1024


def parse_args():
    """Parse the command-line training configuration.

    Returns:
        argparse.Namespace with ``perturbation``, ``train_set``,
        ``batch_size``, ``epoch`` and ``seed``.
    """
    parser = argparse.ArgumentParser(description="Training configuration.")
    parser.add_argument(
        '--perturbation',
        type=str,
        default='hop_tokens4',
        # Reject unknown names up front instead of failing later on the
        # PERTURBATIONS lookup, after the dataset has already been loaded.
        choices=list(PERTURBATIONS),
        help='Type of perturbation to use.',
    )
    parser.add_argument('--train_set', type=str, default='10M', help='Dataset size for training.')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training.')
    parser.add_argument('--epoch', type=int, default=20, help='train epoch')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    return parser.parse_args()


def main():
    """Run one fine-tuning job described by the command-line arguments."""
    args = parse_args()

    # The run id names both the output directories and the dataset config.
    run_id = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
    cache_dir = os.path.join(CKPT_PATH, f"{MODEL_SAVE_NAME}", run_id, "artifacts")
    run_dir = os.path.join(CKPT_PATH, f"{MODEL_SAVE_NAME}", run_id, "runs")
    os.makedirs(cache_dir, exist_ok=True)
    os.makedirs(run_dir, exist_ok=True)

    # The local dataset script is asked for the config named after the run.
    dataset = load_dataset('babylm_dataset_llama.py', name=run_id, trust_remote_code=True)
    train_dataset = dataset['train']

    # Each perturbation carries its own tokenizer.
    # NOTE(review): if that tokenizer adds marker tokens beyond the base
    # vocabulary, the model embeddings would need resizing -- confirm.
    tokenizer = PERTURBATIONS[args.perturbation]['llama_tokenizer']
    if tokenizer.pad_token is None:
        # padding="max_length" below needs a pad token, and Llama tokenizers
        # typically ship without one; fall back to EOS in that case only.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        cache_dir=cache_dir,
    )

    def tokenize_function(examples):
        """Tokenize a batch of raw text to fixed-length sequences."""
        return tokenizer(
            examples['text'],
            padding="max_length",
            truncation=True,
            max_length=MAX_SEQ_LENGTH,
        )

    tokenized_train = train_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
    )

    # mlm=False selects causal-LM collation rather than masked-LM.
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # No evaluation strategy is passed: "no" is already the default, and newer
    # transformers releases renamed the keyword to `eval_strategy` and
    # deprecated the old `evaluation_strategy` spelling.
    training_args = TrainingArguments(
        output_dir=run_dir,
        per_device_train_batch_size=args.batch_size,
        logging_dir='./logs',
        logging_steps=1000,
        save_steps=1000,
        learning_rate=2e-5,
        num_train_epochs=args.epoch,
        seed=args.seed,
        gradient_accumulation_steps=1,
        fp16=True,
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()


if __name__ == "__main__":
    main()
| |
| |