Phase 2D: Fine-tuning pipeline — DomainFinetuneDataset, finetune_domain_model, 139 total tests passing
"""
Fine-tuning data pipeline for JointFusionModel.
Prepares datasets that yield {input_ids, attention_mask, tabular_features, labels}
for supervised fine-tuning of the joint Transformer + DCNv2(PLR) fusion model.
Unlike pre-training (which packs sequences for 100% token utilization), fine-tuning
uses per-user sequences padded to a fixed length — each sample represents one user
with one label.
"""
import logging
from typing import Any, Dict, Sequence
import numpy as np
import torch
from torch.utils.data import Dataset as TorchDataset
from ..tokenizers.domain_tokenizer import DomainTokenizerBuilder
logger = logging.getLogger(__name__)
class DomainFinetuneDataset(TorchDataset):
"""Dataset for fine-tuning JointFusionModel on labeled user data.
Each sample represents one user:
- input_ids: tokenized transaction sequence (padded/truncated to max_length)
- attention_mask: 1 for real tokens, 0 for padding
- tabular_features: numerical feature vector for the user
- labels: target label (float for binary, int for multiclass)
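
    A minimal usage sketch (``seqs``, ``tab_feats``, ``targets``, ``builder`` and
    ``hf_tok`` are placeholders for objects built elsewhere in the pipeline):

        ds = DomainFinetuneDataset(seqs, tab_feats, targets, builder, hf_tok)
        sample = ds[0]
        sample["input_ids"].shape         # torch.Size([max_length])
        sample["tabular_features"].shape  # torch.Size([n_tabular_features])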
"""
    def __init__(self, user_sequences: Sequence[Sequence[Any]], tabular_features: np.ndarray,
                 labels: np.ndarray, builder: DomainTokenizerBuilder, hf_tokenizer: Any,
                 max_length: int = 512):
assert len(user_sequences) == len(tabular_features) == len(labels), (
f"Length mismatch: sequences={len(user_sequences)}, "
f"tabular={len(tabular_features)}, labels={len(labels)}"
)
self.user_sequences = user_sequences
self.tabular_features = np.asarray(tabular_features, dtype=np.float32)
self.labels = np.asarray(labels)
self.builder = builder
self.hf_tokenizer = hf_tokenizer
self.max_length = max_length
if hf_tokenizer.pad_token_id is None:
raise ValueError("Tokenizer must have pad_token set.")
def __len__(self):
return len(self.user_sequences)
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
events = self.user_sequences[idx]
token_strings = self.builder.tokenize_sequence(events, add_bos=True, add_eos=True)
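        # BOS/EOS come from the builder call above, so the HF tokenizer is told
        # not to add its own special tokens to the space-joined token string.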
encoding = self.hf_tokenizer(
" ".join(token_strings), max_length=self.max_length,
truncation=True, padding="max_length", add_special_tokens=False, return_tensors="pt",
)
return {
"input_ids": encoding["input_ids"].squeeze(0),
"attention_mask": encoding["attention_mask"].squeeze(0),
"tabular_features": torch.tensor(self.tabular_features[idx], dtype=torch.float32),
"labels": torch.tensor(self.labels[idx], dtype=torch.float32),
}
    def get_stats(self) -> Dict[str, Any]:
return {
"n_samples": len(self), "max_length": self.max_length,
"n_tabular_features": self.tabular_features.shape[1],
"label_distribution": {
"mean": float(self.labels.mean()), "std": float(self.labels.std()),
"min": float(self.labels.min()), "max": float(self.labels.max()),
},
}
def prepare_finetune_dataset(user_sequences: Sequence[Sequence[Any]], tabular_features: np.ndarray,
                             labels: np.ndarray, builder: DomainTokenizerBuilder, hf_tokenizer: Any,
                             max_length: int = 512) -> DomainFinetuneDataset:
"""Convenience function to create a fine-tuning dataset."""
ds = DomainFinetuneDataset(user_sequences, tabular_features, labels, builder, hf_tokenizer, max_length)
logger.info(f"Fine-tune dataset: {len(ds)} samples, max_length={max_length}, "
f"tabular_features={tabular_features.shape[1]}, label_mean={labels.mean():.3f}")
return ds
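

# Usage sketch (illustrative; `events_by_user`, `features`, `targets`, `builder`,
# and `hf_tok` stand in for objects produced earlier in the pipeline):
#
#     from torch.utils.data import DataLoader
#
#     ds = prepare_finetune_dataset(events_by_user, features, targets, builder,
#                                   hf_tok, max_length=512)
#     # Every sample is padded to max_length, so the default collate_fn stacks
#     # the sample dicts into batched tensors without any custom collation.
#     loader = DataLoader(ds, batch_size=32, shuffle=True)
#     batch = next(iter(loader))
#     # batch["input_ids"]: (32, 512); batch["tabular_features"]: (32, n_features);
#     # batch["labels"]: (32,), ready to feed JointFusionModel for fine-tuning.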