File size: 3,347 Bytes
46a6d37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
Fine-tuning function for JointFusionModel.

Uses the HF Trainer "Pattern A": Trainer inspects JointFusionModel.forward()'s signature,
sees the tabular_features parameter, and passes that dataset column to the model automatically. No Trainer subclass is needed.
"""

import logging
from typing import Optional

from torch.utils.data import Dataset as TorchDataset
from transformers import Trainer, TrainingArguments

# Module-level logger; finetune_domain_model reports its run configuration through it.
logger = logging.getLogger(__name__)


def finetune_domain_model(
    model, train_dataset: TorchDataset, eval_dataset: Optional[TorchDataset] = None,
    output_dir: str = "./domain_finetune_checkpoints", hub_model_id: Optional[str] = None,
    num_epochs: int = 5, per_device_batch_size: int = 32, gradient_accumulation_steps: int = 1,
    learning_rate: float = 1e-4, lr_scheduler_type: str = "cosine",
    warmup_steps: int = 100, weight_decay: float = 0.01, max_grad_norm: float = 1.0,
    bf16: bool = False, fp16: bool = False, logging_steps: int = 50,
    save_steps: int = 500, save_strategy: str = "steps", eval_steps: int = 500,
    save_total_limit: int = 3, dataloader_num_workers: int = 4, report_to: str = "none",
    run_name: Optional[str] = None, seed: int = 42, gradient_checkpointing: bool = False,
    resume_from_checkpoint: Optional[str] = None, **extra_training_args,
) -> Trainer:
    """Fine-tune a JointFusionModel with the Hugging Face ``Trainer``.

    No Trainer subclass is needed (Pattern A). With ``remove_unused_columns=True``
    the Trainer keeps every dataset key that matches a parameter of
    ``model.forward()`` (plus the label columns). ``tabular_features`` is
    therefore passed straight through to the model alongside the text inputs.

    Dataset must yield: {input_ids, attention_mask, tabular_features, labels}.

    Args:
        model: The model to train (expected to be a JointFusionModel).
        train_dataset: Training examples.
        eval_dataset: Optional evaluation examples. When falsy (``None``, or a
            sized dataset of length 0), periodic evaluation is disabled.
        output_dir: Directory that checkpoints are written to.
        hub_model_id: Hub repository id. When given, ``push_to_hub`` is enabled
            and the final model is pushed once training finishes.
        num_epochs: Forwarded as ``num_train_epochs``.
        per_device_batch_size: Used for both the train and the eval batch size.
        gradient_accumulation_steps, learning_rate, lr_scheduler_type,
        warmup_steps, weight_decay, max_grad_norm, bf16, fp16, logging_steps,
        save_steps, save_strategy, eval_steps, save_total_limit,
        dataloader_num_workers, report_to, run_name, seed,
        gradient_checkpointing: Forwarded unchanged to ``TrainingArguments``.
            ``eval_steps`` only applies when evaluation is enabled.
        resume_from_checkpoint: Passed to ``Trainer.train``.
        **extra_training_args: Additional ``TrainingArguments`` keywords. They
            are applied last, so they may also override defaults this function
            sets itself (e.g. ``disable_tqdm``, ``eval_strategy``,
            ``per_device_eval_batch_size``, ``remove_unused_columns``).

    Returns:
        The ``Trainer`` instance after ``train()`` has completed.

    NOTE(review): with ``gradient_checkpointing=True`` the Trainer calls
    ``model.gradient_checkpointing_enable()``. Confirm that JointFusionModel
    provides that method.
    """
    # Evaluate periodically only when an eval set was supplied. Truthiness is
    # intentional: an empty sized dataset also disables evaluation.
    do_eval = bool(eval_dataset)

    training_kwargs = dict(
        output_dir=output_dir, num_train_epochs=num_epochs,
        per_device_train_batch_size=per_device_batch_size,
        per_device_eval_batch_size=per_device_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate, lr_scheduler_type=lr_scheduler_type,
        warmup_steps=warmup_steps, weight_decay=weight_decay, max_grad_norm=max_grad_norm,
        bf16=bf16, fp16=fp16,
        logging_strategy="steps", logging_steps=logging_steps,
        logging_first_step=True, disable_tqdm=True,
        eval_strategy="steps" if do_eval else "no",
        eval_steps=eval_steps if do_eval else None,
        save_strategy=save_strategy, save_steps=save_steps, save_total_limit=save_total_limit,
        # hub_model_id is None exactly when pushing is disabled.
        push_to_hub=hub_model_id is not None, hub_model_id=hub_model_id,
        dataloader_num_workers=dataloader_num_workers, report_to=report_to,
        run_name=run_name, seed=seed, gradient_checkpointing=gradient_checkpointing,
        # Keeps only keys named in model.forward()'s signature (plus labels).
        remove_unused_columns=True,
    )
    # Apply caller extras last so they override the defaults above. Splatting
    # both into a single call raised "got multiple values for keyword argument"
    # on any collision.
    training_kwargs.update(extra_training_args)
    training_args = TrainingArguments(**training_kwargs)

    n_params = sum(p.numel() for p in model.parameters())
    # Iterable-style datasets need not define __len__. Don't let the summary
    # log line crash the run for them.
    try:
        n_train = f"{len(train_dataset):,}"
    except TypeError:
        n_train = "unknown"
    logger.info("=== Domain Fine-Tuning (Joint Fusion) ===")
    logger.info(f"  Model params: {n_params:,}, Train samples: {n_train}")
    logger.info(f"  Batch: {per_device_batch_size}x{gradient_accumulation_steps}, "
                f"Epochs: {num_epochs}, LR: {learning_rate} ({lr_scheduler_type})")

    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset)
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # Use the resolved arguments so a push_to_hub override via extras is honoured.
    if training_args.push_to_hub:
        trainer.push_to_hub()

    return trainer