"""
Training script for an OlmoE model with adapters on the mlfoundations/dclm-baseline-1.0 dataset.
This script demonstrates parameter-efficient fine-tuning using adapters.
"""

import os
import logging
from dataclasses import dataclass, field
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, IterableDataset
from torch.optim import AdamW

from datasets import load_dataset
from transformers import (
    OlmoConfig,
    OlmoForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
    get_scheduler,
)
from tqdm import tqdm
from accelerate import Accelerator

|
| from modeling_olmoe import ( |
| OlmoEWithAdaptersForCausalLM, |
| OlmoEForCausalLM, |
| ) |
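
# OlmoEWithAdaptersForCausalLM is assumed to live in a local modeling_olmoe.py:
# an Olmo-architecture causal LM with bottleneck adapters sized by
# config.adapter_size, exposing the freeze_base_model() and
# get_trainable_parameters() helpers used below. A typical bottleneck adapter
# (a sketch, not necessarily the modeling_olmoe implementation) looks like:
#
#   class Adapter(nn.Module):
#       def __init__(self, hidden_size: int, adapter_size: int):
#           super().__init__()
#           self.down = nn.Linear(hidden_size, adapter_size)  # project down
#           self.up = nn.Linear(adapter_size, hidden_size)    # project back up
#
#       def forward(self, x):
#           # Residual bottleneck: only these small layers need gradients.
#           return x + self.up(torch.relu(self.down(x)))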


logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)


@dataclass
class ModelArguments:
    """Arguments for model configuration."""

    model_name_or_path: str = field(
        default="allenai/OLMo-7B-Instruct",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    adapter_size: int = field(
        default=64,
        metadata={"help": "Hidden size of the adapter bottleneck layers"},
    )
    freeze_base_model: bool = field(
        default=True,
        metadata={"help": "Whether to freeze all parameters except the adapters"},
    )
    checkpoint_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Directory to save model checkpoints (defaults to output_dir)"},
    )


@dataclass
class DataArguments:
    """Arguments for dataset configuration."""

    dataset_name: str = field(
        default="mlfoundations/dclm-baseline-1.0",
        metadata={"help": "Dataset name or path for training"},
    )
    streaming: bool = field(
        default=True,
        metadata={"help": "Whether to stream the dataset instead of downloading it fully"},
    )
    streaming_buffer_size: int = field(
        default=8192,
        metadata={"help": "Shuffle buffer size for the streaming dataset"},
    )
    max_seq_length: int = field(
        default=1024,
        metadata={"help": "Maximum sequence length for training"},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "Number of DataLoader workers"},
    )
    text_column_name: str = field(
        default="text",
        metadata={"help": "Column name containing the text data"},
    )


class StreamingTextDataset(IterableDataset):
    """Streams text from a hub dataset and packs it into fixed-length blocks."""

    def __init__(
        self,
        dataset_name: str,
        tokenizer,
        max_seq_length: int,
        streaming: bool = True,
        text_column_name: str = "text",
        buffer_size: int = 8192,
        split: str = "train",
    ):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.text_column_name = text_column_name

        self.dataset = load_dataset(
            dataset_name,
            split=split,
            streaming=streaming,
        )
        if streaming:
            # Approximate shuffling through a fixed-size buffer.
            self.dataset = self.dataset.shuffle(buffer_size=buffer_size)

    def __iter__(self):
        buffer = []

        for example in self.dataset:
            text = example[self.text_column_name]
            if not text or not text.strip():
                continue

            tokenized = self.tokenizer(
                text,
                truncation=False,
                return_attention_mask=False,
                return_token_type_ids=False,
                add_special_tokens=False,
            )
            buffer.extend(tokenized["input_ids"])
            # Mark document boundaries with EOS so packed blocks do not run
            # documents together without any separator.
            if self.tokenizer.eos_token_id is not None:
                buffer.append(self.tokenizer.eos_token_id)

            # Emit full blocks; for causal LM, labels equal input_ids
            # (the model shifts them internally).
            while len(buffer) >= self.max_seq_length:
                block = buffer[: self.max_seq_length]
                buffer = buffer[self.max_seq_length :]
                yield {
                    "input_ids": torch.tensor(block, dtype=torch.long),
                    "labels": torch.tensor(block, dtype=torch.long),
                }
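
    # Packing example (illustrative token ids): with max_seq_length=4, EOS=0,
    # and documents tokenized to [1, 2, 3] and [4, 5, 6, 7], the buffer grows to
    # [1, 2, 3, 0, 4, 5, 6, 7, 0], yielding blocks [1, 2, 3, 0] and [4, 5, 6, 7]
    # while the trailing [0] waits for more tokens.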


def create_optimizer_and_scheduler(
    model: nn.Module,
    args: TrainingArguments,
    num_training_steps: int,
) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
    """Create the optimizer and learning rate scheduler."""

    # Prefer the model's own notion of trainable parameters (the adapters);
    # fall back to anything with requires_grad set.
    if hasattr(model, "get_trainable_parameters"):
        optimizer_params = model.get_trainable_parameters()
        logger.info(f"Training {len(optimizer_params)} adapter parameter tensors")
    else:
        optimizer_params = [p for p in model.parameters() if p.requires_grad]
        logger.info(f"Training {len(optimizer_params)} parameter tensors")

    optimizer = AdamW(
        optimizer_params,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_epsilon,
        weight_decay=args.weight_decay,
    )

    scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_training_steps,
    )

    return optimizer, scheduler
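
# Note: args.lr_scheduler_type comes from TrainingArguments and accepts the
# standard transformers schedules, e.g. "linear" (the default), "cosine",
# "polynomial", or "constant_with_warmup".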


def train(
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: TrainingArguments,
):
    """Main training loop."""

    accelerator = Accelerator(
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        mixed_precision="fp16" if training_args.fp16 else "bf16" if training_args.bf16 else "no",
    )

    logger.info(accelerator.state)
    if accelerator.is_local_main_process:
        logger.info(f"Model arguments: {model_args}")
        logger.info(f"Data arguments: {data_args}")
        logger.info(f"Training arguments: {training_args}")

    set_seed(training_args.seed)

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    config = OlmoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    config.adapter_size = model_args.adapter_size

    logger.info(f"Loading OlmoE model with adapters from {model_args.model_name_or_path}")
    base_model = OlmoForCausalLM.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)

    model = OlmoEWithAdaptersForCausalLM(config)
    # Copy the pretrained weights; strict=False leaves the freshly initialized
    # adapter parameters in place, since they have no counterpart in the base model.
    model.load_state_dict(base_model.state_dict(), strict=False)
    del base_model  # Release the duplicate copy of the base weights.

    if model_args.freeze_base_model:
        logger.info("Freezing base model parameters")
        model.freeze_base_model()

    logger.info(f"Loading dataset: {data_args.dataset_name}")
    train_dataset = StreamingTextDataset(
        dataset_name=data_args.dataset_name,
        tokenizer=tokenizer,
        max_seq_length=data_args.max_seq_length,
        streaming=data_args.streaming,
        buffer_size=data_args.streaming_buffer_size,
        text_column_name=data_args.text_column_name,
    )

    # Blocks are already fixed-length and carry labels, so the collator only
    # needs to stack them; mlm=False keeps causal-LM-style labels.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=training_args.per_device_train_batch_size,
        collate_fn=data_collator,
        num_workers=data_args.preprocessing_num_workers or 0,
    )

    # The dataset is an unbounded stream with no epoch length, so training is
    # bounded by --max_steps.
    num_training_steps = training_args.max_steps

    optimizer, lr_scheduler = create_optimizer_and_scheduler(
        model=model,
        args=training_args,
        num_training_steps=num_training_steps,
    )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    total_batch_size = (
        training_args.per_device_train_batch_size
        * accelerator.num_processes
        * training_args.gradient_accumulation_steps
    )
    logger.info(f"Total batch size (with parallelism & accumulation): {total_batch_size}")
    logger.info(f"Number of training steps: {num_training_steps}")
    logger.info(f"Number of warmup steps: {training_args.warmup_steps}")

    progress_bar = tqdm(
        range(num_training_steps),
        disable=not accelerator.is_local_main_process,
        desc="Training",
    )
    completed_steps = 0

    logger.info("Starting training...")
    model.train()

    for batch in train_dataloader:
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)

            # Under accumulate(), these are no-ops on non-sync micro-batches.
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # Count an optimizer step only when gradients were actually synced,
        # i.e. once per full accumulation cycle.
        if not accelerator.sync_gradients:
            continue
        progress_bar.update(1)
        completed_steps += 1

        if completed_steps % training_args.logging_steps == 0:
            loss_value = accelerator.gather(loss).mean().item()
            logger.info(
                f"Step {completed_steps}: loss = {loss_value:.4f}, "
                f"lr = {lr_scheduler.get_last_lr()[0]:.8f}"
            )
            # No-op unless trackers were configured via Accelerator(log_with=...).
            accelerator.log(
                {"loss": loss_value, "learning_rate": lr_scheduler.get_last_lr()[0]},
                step=completed_steps,
            )

        if training_args.save_steps > 0 and completed_steps % training_args.save_steps == 0:
            if model_args.checkpoint_dir is not None:
                output_dir = os.path.join(model_args.checkpoint_dir, f"checkpoint-{completed_steps}")
                accelerator.save_state(output_dir)
                logger.info(f"Saved checkpoint to {output_dir}")

                if accelerator.is_main_process:
                    unwrapped_model = accelerator.unwrap_model(model)
                    unwrapped_model.save_pretrained(
                        output_dir,
                        is_main_process=accelerator.is_main_process,
                        save_function=accelerator.save,
                    )
                    tokenizer.save_pretrained(output_dir)

        if completed_steps >= num_training_steps:
            break

    # Save the final model.
    accelerator.wait_for_everyone()
    if model_args.checkpoint_dir is not None:
        output_dir = os.path.join(model_args.checkpoint_dir, "final-model")
        accelerator.save_state(output_dir)

        if accelerator.is_main_process:
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                output_dir,
                is_main_process=accelerator.is_main_process,
                save_function=accelerator.save,
            )
            tokenizer.save_pretrained(output_dir)

        logger.info(f"Saved final model to {output_dir}")

    logger.info("Training complete!")


def main():
    """Main entry point."""
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Default the checkpoint directory to the training output directory.
    if model_args.checkpoint_dir is None:
        model_args.checkpoint_dir = training_args.output_dir
    os.makedirs(model_args.checkpoint_dir, exist_ok=True)

    train(model_args, data_args, training_args)


if __name__ == "__main__":
    main()