"""
Pre-training function for DomainTransformer.
Uses HuggingFace Trainer with DataCollatorForLanguageModeling(mlm=False)
which automatically sets labels = input_ids and masks padding with -100.
Usage:
    from domain_tokenizer.training import pretrain_domain_model, prepare_clm_dataset

    dataset = prepare_clm_dataset(user_sequences, builder, hf_tokenizer, block_size=512)
    config = DomainTransformerConfig.from_preset("24m", vocab_size=hf_tokenizer.vocab_size)
    model = DomainTransformerForCausalLM(config)
    pretrain_domain_model(model, hf_tokenizer, dataset)
"""
import logging
from typing import Optional
from datasets import Dataset as HFDataset
from transformers import (
    DataCollatorForLanguageModeling,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
)

logger = logging.getLogger(__name__)

def pretrain_domain_model(
    model,
    tokenizer: PreTrainedTokenizerFast,
    train_dataset: HFDataset,
    eval_dataset: Optional[HFDataset] = None,
    output_dir: str = "./domain_pretrain_checkpoints",
    hub_model_id: Optional[str] = None,
    num_epochs: int = 10,
    per_device_batch_size: int = 32,
    gradient_accumulation_steps: int = 4,
    learning_rate: float = 3e-4,
    lr_scheduler_type: str = "cosine",
    warmup_steps: int = 500,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
    bf16: bool = False,
    fp16: bool = False,
    logging_steps: int = 50,
    save_steps: int = 500,
    eval_steps: int = 500,
    save_total_limit: int = 3,
    dataloader_num_workers: int = 4,
    report_to: str = "none",
    run_name: Optional[str] = None,
    seed: int = 42,
    gradient_checkpointing: bool = False,
    resume_from_checkpoint: Optional[str] = None,
    **extra_training_args,
) -> Trainer:
"""Pre-train a DomainTransformerForCausalLM with HF Trainer.
The dataset should be packed via prepare_clm_dataset() for 100% token utilization.
Returns:
The Trainer instance (for inspection, continued training, etc.).
"""
    if tokenizer.pad_token_id is None:
        raise ValueError(
            "Tokenizer must have pad_token set. "
            "DomainTokenizerBuilder.build() should set this automatically."
        )
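    # Note (an assumption, not part of the builder flow): hand-built tokenizers
    # can typically set tokenizer.pad_token = tokenizer.eos_token before
    # calling this function.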
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
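    # With mlm=False the collator copies input_ids into labels and replaces
    # pad-token positions with -100 so the loss ignores them, e.g.:
    #   input_ids [5, 17, 9, <pad>] -> labels [5, 17, 9, -100]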
    push_to_hub = hub_model_id is not None
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=per_device_batch_size,
        per_device_eval_batch_size=per_device_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        lr_scheduler_type=lr_scheduler_type,
        warmup_steps=warmup_steps,
        weight_decay=weight_decay,
        max_grad_norm=max_grad_norm,
        bf16=bf16,
        fp16=fp16,
        logging_strategy="steps",
        logging_steps=logging_steps,
        logging_first_step=True,
        disable_tqdm=True,
        eval_strategy="steps" if eval_dataset else "no",
        eval_steps=eval_steps if eval_dataset else None,
        save_strategy="steps",
        save_steps=save_steps,
        save_total_limit=save_total_limit,
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id if push_to_hub else None,
        dataloader_num_workers=dataloader_num_workers,
        report_to=report_to,
        run_name=run_name,
        seed=seed,
        gradient_checkpointing=gradient_checkpointing,
        remove_unused_columns=True,
        **extra_training_args,
    )
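    # Any **extra_training_args are forwarded verbatim to TrainingArguments
    # above, so callers can override or add fields (e.g. torch_compile=True
    # on a recent transformers version).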
    effective_batch = per_device_batch_size * gradient_accumulation_steps
    n_params = sum(p.numel() for p in model.parameters())
    logger.info("=== Domain Pre-Training ===")
    logger.info(f" Model params: {n_params:,}")
    logger.info(f" Train samples: {len(train_dataset):,}")
    logger.info(f" Block size: {len(train_dataset[0]['input_ids'])}")
    logger.info(f" Batch size: {per_device_batch_size} x {gradient_accumulation_steps} = {effective_batch}")
    logger.info(f" Epochs: {num_epochs}, LR: {learning_rate} ({lr_scheduler_type})")
    logger.info(f" Push to hub: {hub_model_id if push_to_hub else 'disabled'}")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        processing_class=tokenizer,
    )
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    if push_to_hub:
        logger.info(f"Pushing model to hub: {hub_model_id}")
        trainer.push_to_hub()

    return trainer
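

# ---------------------------------------------------------------------------
# Hedged sketch: prepare_clm_dataset() is referenced in the module docstring
# but implemented elsewhere in the package. What follows is an ASSUMPTION,
# not the real implementation: a standard concat-and-chunk CLM packing
# routine (tokenize, join with EOS, split into fixed block_size windows)
# producing the padding-free, fully-utilized blocks this trainer expects.
# `builder.render()` is a hypothetical API for serializing user sequences.
def _prepare_clm_dataset_sketch(user_sequences, builder, hf_tokenizer, block_size=512):
    ids = []
    for seq in user_sequences:
        text = builder.render(seq)  # hypothetical: one sequence -> plain text
        ids.extend(hf_tokenizer(text, add_special_tokens=False)["input_ids"])
        ids.append(hf_tokenizer.eos_token_id)  # mark the sequence boundary
    # Drop the tail remainder so every block is exactly block_size tokens;
    # fixed-length packed blocks need no padding, hence 100% token utilization.
    n_blocks = len(ids) // block_size
    blocks = [ids[i * block_size : (i + 1) * block_size] for i in range(n_blocks)]
    return HFDataset.from_dict({"input_ids": blocks})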