"""
Pre-training function for DomainTransformer.

Uses HuggingFace Trainer with DataCollatorForLanguageModeling(mlm=False)
which automatically sets labels = input_ids and masks padding with -100.

Usage:
    from domain_tokenizer.training import pretrain_domain_model, prepare_clm_dataset
    dataset = prepare_clm_dataset(user_sequences, builder, hf_tokenizer, block_size=512)
    config = DomainTransformerConfig.from_preset("24m", vocab_size=hf_tokenizer.vocab_size)
    model = DomainTransformerForCausalLM(config)
    pretrain_domain_model(model, hf_tokenizer, dataset)
"""

import logging
from typing import Optional

from datasets import Dataset as HFDataset
from transformers import (
    DataCollatorForLanguageModeling,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
)

logger = logging.getLogger(__name__)


def pretrain_domain_model(
    model,
    tokenizer: PreTrainedTokenizerFast,
    train_dataset: HFDataset,
    eval_dataset: Optional[HFDataset] = None,
    output_dir: str = "./domain_pretrain_checkpoints",
    hub_model_id: Optional[str] = None,
    num_epochs: int = 10,
    per_device_batch_size: int = 32,
    gradient_accumulation_steps: int = 4,
    learning_rate: float = 3e-4,
    lr_scheduler_type: str = "cosine",
    warmup_steps: int = 500,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
    bf16: bool = False,
    fp16: bool = False,
    logging_steps: int = 50,
    save_steps: int = 500,
    eval_steps: int = 500,
    save_total_limit: int = 3,
    dataloader_num_workers: int = 4,
    report_to: str = "none",
    run_name: Optional[str] = None,
    seed: int = 42,
    gradient_checkpointing: bool = False,
    resume_from_checkpoint: Optional[str] = None,
    **extra_training_args,
) -> Trainer:
    """Pre-train a DomainTransformerForCausalLM with HF Trainer.

    The dataset should be packed via prepare_clm_dataset() for 100% token utilization.

    Returns:
        The Trainer instance (for inspection, continued training, etc.).
    """
    if tokenizer.pad_token_id is None:
        raise ValueError(
            "Tokenizer must have pad_token set. "
            "DomainTokenizerBuilder.build() should set this automatically."
        )

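    # mlm=False selects causal-LM collation: labels are a copy of input_ids with
    # padding positions replaced by -100 so the loss ignores them.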
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    push_to_hub = hub_model_id is not None

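    # Assemble the HF TrainingArguments; **extra_training_args lets callers
    # override or extend any field without changing this wrapper.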
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=per_device_batch_size,
        per_device_eval_batch_size=per_device_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        lr_scheduler_type=lr_scheduler_type,
        warmup_steps=warmup_steps,
        weight_decay=weight_decay,
        max_grad_norm=max_grad_norm,
        bf16=bf16,
        fp16=fp16,
        logging_strategy="steps",
        logging_steps=logging_steps,
        logging_first_step=True,
        disable_tqdm=True,
        eval_strategy="steps" if eval_dataset else "no",
        eval_steps=eval_steps if eval_dataset else None,
        save_strategy="steps",
        save_steps=save_steps,
        save_total_limit=save_total_limit,
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id if push_to_hub else None,
        dataloader_num_workers=dataloader_num_workers,
        report_to=report_to,
        run_name=run_name,
        seed=seed,
        gradient_checkpointing=gradient_checkpointing,
        remove_unused_columns=True,
        **extra_training_args,
    )

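    # Summarize the effective configuration once before training starts.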
    effective_batch = per_device_batch_size * gradient_accumulation_steps
    n_params = sum(p.numel() for p in model.parameters())

    logger.info(f"=== Domain Pre-Training ===")
    logger.info(f"  Model params: {n_params:,}")
    logger.info(f"  Train samples: {len(train_dataset):,}")
    logger.info(f"  Block size: {len(train_dataset[0]['input_ids'])}")
    logger.info(f"  Batch size: {per_device_batch_size} x {gradient_accumulation_steps} = {effective_batch}")
    logger.info(f"  Epochs: {num_epochs}, LR: {learning_rate} ({lr_scheduler_type})")
    logger.info(f"  Push to hub: {hub_model_id if push_to_hub else 'disabled'}")

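    # processing_class is the current name for the Trainer's old `tokenizer`
    # argument in recent transformers releases; it is saved alongside checkpoints.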
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        processing_class=tokenizer,
    )

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    if push_to_hub:
        logger.info(f"Pushing model to hub: {hub_model_id}")
        trainer.push_to_hub()

    return trainer
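

if __name__ == "__main__":
    # Minimal smoke-test sketch (added for illustration; not part of the original
    # module). It exercises pretrain_domain_model() end to end without domain data
    # by substituting a tiny randomly initialized GPT-2 and random token ids in
    # place of DomainTransformerForCausalLM and prepare_clm_dataset().
    import random

    from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # GPT-2 ships without a pad token

    block_size = 32
    rows = [
        {"input_ids": [random.randrange(tok.vocab_size) for _ in range(block_size)]}
        for _ in range(8)
    ]
    toy_dataset = HFDataset.from_list(rows)

    toy_model = GPT2LMHeadModel(
        GPT2Config(
            vocab_size=tok.vocab_size,
            n_positions=block_size,
            n_embd=64,
            n_layer=2,
            n_head=2,
        )
    )

    pretrain_domain_model(
        toy_model,
        tok,
        toy_dataset,
        num_epochs=1,
        per_device_batch_size=2,
        gradient_accumulation_steps=1,
        warmup_steps=0,
        logging_steps=1,
        save_steps=10_000,
    )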