| |
| |
| |
| """ |
| Run script for fine-tuning OlmoE with adapters on specific text domains. |
| Handles argument parsing and configuration. |
| """ |
|
|
| import argparse |
| import os |
| import sys |
| from dataclasses import dataclass, field |
| from typing import Optional |
|
|
| from transformers import ( |
| HfArgumentParser, |
| TrainingArguments, |
| ) |
|
|
|
|
| @dataclass |
| class ScriptArguments: |
| """ |
| Arguments for the run script that aren't covered by TrainingArguments. |
| """ |
| model_path: str = field( |
| default="allenai/OLMo-7B-Instruct", |
| metadata={"help": "Path to the model to fine-tune"} |
| ) |
| output_dir: str = field( |
| default="./output_olmoe_adapter", |
| metadata={"help": "Directory to save the model and logs"} |
| ) |
| adapter_size: int = field( |
| default=64, |
| metadata={"help": "Size of the adapter layers"} |
| ) |
| dataset_name: str = field( |
| default="mlfoundations/dclm-baseline-1.0", |
| metadata={"help": "Name of the dataset to use"} |
| ) |
| max_steps: int = field( |
| default=10000, |
| metadata={"help": "Maximum number of training steps"} |
| ) |
| learning_rate: float = field( |
| default=5e-5, |
| metadata={"help": "Learning rate for fine-tuning"} |
| ) |
| per_device_batch_size: int = field( |
| default=8, |
| metadata={"help": "Batch size per device"} |
| ) |
| gradient_accumulation_steps: int = field( |
| default=1, |
| metadata={"help": "Number of steps to accumulate gradients"} |
| ) |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| def main(): |
| |
| parser = HfArgumentParser(ScriptArguments) |
| args = parser.parse_args_into_dataclasses()[0] |
| |
| |
| os.makedirs(args.output_dir, exist_ok=True) |
| |
| |
| cmd = [ |
| "python", |
| "train_olmoe_adapter.py", |
| |
| |
| f"--model_name_or_path={args.model_path}", |
| f"--adapter_size={args.adapter_size}", |
| "--freeze_base_model=True", |
| f"--checkpoint_dir={args.output_dir}", |
| |
| |
| f"--dataset_name={args.dataset_name}", |
| "--streaming=True", |
| "--streaming_buffer_size=8192", |
| "--max_seq_length=1024", |
| |
| |
| f"--output_dir={args.output_dir}", |
| f"--per_device_train_batch_size={args.per_device_batch_size}", |
| f"--gradient_accumulation_steps={args.gradient_accumulation_steps}", |
| f"--learning_rate={args.learning_rate}", |
| f"--max_steps={args.max_steps}", |
| "--warmup_steps=500", |
| "--logging_steps=10", |
| "--save_steps=1000", |
| "--save_total_limit=2", |
| "--dataloader_num_workers=4", |
| "--seed=42", |
| ] |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| cmd_str = " ".join(cmd) |
| print(f"Running command: {cmd_str}") |
| |
| |
| os.environ["PYTHONPATH"] = os.getcwd() |
| ret = os.system(cmd_str) |
| |
| if ret != 0: |
| print(f"Training failed with exit code {ret}") |
| sys.exit(ret) |
| |
| print("Training completed successfully!") |
|
|
|
|
| if __name__ == "__main__": |
| main() |