| """ |
| Pre-training function for DomainTransformer. |
| |
| Uses HuggingFace Trainer with DataCollatorForLanguageModeling(mlm=False) |
| which automatically sets labels = input_ids and masks padding with -100. |
| |
| Usage: |
| from domain_tokenizer.training import pretrain_domain_model, prepare_clm_dataset |
| dataset = prepare_clm_dataset(user_sequences, builder, hf_tokenizer, block_size=512) |
| config = DomainTransformerConfig.from_preset("24m", vocab_size=hf_tokenizer.vocab_size) |
| model = DomainTransformerForCausalLM(config) |
| pretrain_domain_model(model, hf_tokenizer, dataset) |
| """ |
|
|
import logging
from typing import Optional

from datasets import Dataset as HFDataset
from transformers import (
    DataCollatorForLanguageModeling,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
)

logger = logging.getLogger(__name__)


def pretrain_domain_model(
    model,
    tokenizer: PreTrainedTokenizerFast,
    train_dataset: HFDataset,
    eval_dataset: Optional[HFDataset] = None,
    output_dir: str = "./domain_pretrain_checkpoints",
    hub_model_id: Optional[str] = None,
    num_epochs: int = 10,
    per_device_batch_size: int = 32,
    gradient_accumulation_steps: int = 4,
    learning_rate: float = 3e-4,
    lr_scheduler_type: str = "cosine",
    warmup_steps: int = 500,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
    bf16: bool = False,
    fp16: bool = False,
    logging_steps: int = 50,
    save_steps: int = 500,
    eval_steps: int = 500,
    save_total_limit: int = 3,
    dataloader_num_workers: int = 4,
    report_to: str = "none",
    run_name: Optional[str] = None,
    seed: int = 42,
    gradient_checkpointing: bool = False,
    resume_from_checkpoint: Optional[str] = None,
    **extra_training_args,
) -> Trainer:
| """Pre-train a DomainTransformerForCausalLM with HF Trainer. |
| |
| The dataset should be packed via prepare_clm_dataset() for 100% token utilization. |
| |
| Returns: |
| The Trainer instance (for inspection, continued training, etc.). |
| """ |
    if tokenizer.pad_token_id is None:
        raise ValueError(
            "Tokenizer must have pad_token set. "
            "DomainTokenizerBuilder.build() should set this automatically."
        )

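    # mlm=False: the collator copies input_ids into labels and sets padded positions to -100.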
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    push_to_hub = hub_model_id is not None

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=per_device_batch_size,
        per_device_eval_batch_size=per_device_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        lr_scheduler_type=lr_scheduler_type,
        warmup_steps=warmup_steps,
        weight_decay=weight_decay,
        max_grad_norm=max_grad_norm,
        bf16=bf16,
        fp16=fp16,
        logging_strategy="steps",
        logging_steps=logging_steps,
        logging_first_step=True,
        disable_tqdm=True,
        eval_strategy="steps" if eval_dataset else "no",
        eval_steps=eval_steps if eval_dataset else None,
        save_strategy="steps",
        save_steps=save_steps,
        save_total_limit=save_total_limit,
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id if push_to_hub else None,
        dataloader_num_workers=dataloader_num_workers,
        report_to=report_to,
        run_name=run_name,
        seed=seed,
        gradient_checkpointing=gradient_checkpointing,
        remove_unused_columns=True,
        **extra_training_args,
    )

    effective_batch = per_device_batch_size * gradient_accumulation_steps
    n_params = sum(p.numel() for p in model.parameters())

| logger.info(f"=== Domain Pre-Training ===") |
| logger.info(f" Model params: {n_params:,}") |
| logger.info(f" Train samples: {len(train_dataset):,}") |
| logger.info(f" Block size: {len(train_dataset[0]['input_ids'])}") |
| logger.info(f" Batch size: {per_device_batch_size} x {gradient_accumulation_steps} = {effective_batch}") |
| logger.info(f" Epochs: {num_epochs}, LR: {learning_rate} ({lr_scheduler_type})") |
| logger.info(f" Push to hub: {hub_model_id if push_to_hub else 'disabled'}") |
|
|
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
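        # processing_class replaces the deprecated `tokenizer=` Trainer argument in recent transformers releases.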
        processing_class=tokenizer,
    )

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    if push_to_hub:
        logger.info(f"Pushing model to hub: {hub_model_id}")
        trainer.push_to_hub()

    return trainer
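

if __name__ == "__main__":
    # Illustrative smoke test only; a minimal sketch, not part of the library API.
    # It stands in a tiny randomly initialised GPT-2 for DomainTransformerForCausalLM
    # (whose import path is not shown in this module) and packs a toy dataset by hand
    # to mimic what prepare_clm_dataset() is described to do: concatenate token
    # sequences and slice them into fixed-size blocks. Requires network access to
    # download the "gpt2" tokenizer files the first time it runs.
    from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

    logging.basicConfig(level=logging.INFO)

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # pretrain_domain_model() requires a pad token

    block_size = 64
    texts = ["user logged in", "user viewed product page", "user checked out"] * 100
    # Hand-rolled packing: join all token ids into one stream, then cut full blocks.
    stream = [tid for text in texts for tid in tok(text)["input_ids"] + [tok.eos_token_id]]
    blocks = [
        stream[i : i + block_size]
        for i in range(0, len(stream) - block_size + 1, block_size)
    ]
    toy_dataset = HFDataset.from_list([{"input_ids": block} for block in blocks])

    # Tiny stand-in model so the whole run finishes in seconds on CPU.
    tiny_model = GPT2LMHeadModel(
        GPT2Config(
            vocab_size=tok.vocab_size,
            n_positions=block_size,
            n_embd=64,
            n_layer=2,
            n_head=2,
        )
    )

    pretrain_domain_model(
        tiny_model,
        tok,
        toy_dataset,
        output_dir="./smoke_test_checkpoints",
        num_epochs=1,
        per_device_batch_size=4,
        gradient_accumulation_steps=1,
        warmup_steps=0,
        logging_steps=1,
        save_steps=10_000,
        dataloader_num_workers=0,
    )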