| """ |
| Fine-tuning data pipeline for JointFusionModel. |
| |
| Prepares datasets that yield {input_ids, attention_mask, tabular_features, labels} |
| for supervised fine-tuning of the joint Transformer + DCNv2(PLR) fusion model. |
| |
| Unlike pre-training (which packs sequences for 100% token utilization), fine-tuning |
| uses per-user sequences padded to a fixed length — each sample represents one user |
| with one label. |
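
Typical usage (a sketch; `builder` is assumed to be a fitted DomainTokenizerBuilder
and `hf_tokenizer` any HuggingFace tokenizer with a pad token set; the data
variables are placeholders for project-specific loading code):

    from torch.utils.data import DataLoader

    dataset = prepare_finetune_dataset(
        user_sequences, tabular_features, labels, builder, hf_tokenizer, max_length=512
    )
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

Because every field of a sample is a fixed-size tensor, the default DataLoader
collate_fn is sufficient; no custom collator is needed.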
| """ |
|
|
| import logging |
| from typing import Any, Dict, Sequence |
|
|
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset as TorchDataset |
|
|
| from ..tokenizers.domain_tokenizer import DomainTokenizerBuilder |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| class DomainFinetuneDataset(TorchDataset): |
| """Dataset for fine-tuning JointFusionModel on labeled user data. |
| |
| Each sample represents one user: |
| - input_ids: tokenized transaction sequence (padded/truncated to max_length) |
| - attention_mask: 1 for real tokens, 0 for padding |
| - tabular_features: numerical feature vector for the user |
    - labels: target label (cast to float32 for binary/regression targets,
      int64 for integer-typed multiclass targets)
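
    Example of a single sample (shapes assume max_length=512 and, purely for
    illustration, 16 tabular features):

        sample = dataset[0]
        sample["input_ids"].shape         # torch.Size([512])
        sample["attention_mask"].shape    # torch.Size([512])
        sample["tabular_features"].shape  # torch.Size([16])
        sample["labels"]                  # scalar tensor (float32 or int64)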
| """ |
|
|
    def __init__(
        self,
        user_sequences: Sequence[Any],
        tabular_features: np.ndarray,
        labels: np.ndarray,
        builder: DomainTokenizerBuilder,
        hf_tokenizer: Any,
        max_length: int = 512,
    ) -> None:
| assert len(user_sequences) == len(tabular_features) == len(labels), ( |
| f"Length mismatch: sequences={len(user_sequences)}, " |
| f"tabular={len(tabular_features)}, labels={len(labels)}" |
| ) |
| self.user_sequences = user_sequences |
| self.tabular_features = np.asarray(tabular_features, dtype=np.float32) |
| self.labels = np.asarray(labels) |
| self.builder = builder |
| self.hf_tokenizer = hf_tokenizer |
| self.max_length = max_length |
| if hf_tokenizer.pad_token_id is None: |
| raise ValueError("Tokenizer must have pad_token set.") |
|
|
| def __len__(self): |
| return len(self.user_sequences) |
|
|
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        events = self.user_sequences[idx]
        # The builder already inserts BOS/EOS domain tokens, so the HF tokenizer
        # only maps the space-joined token strings to ids (add_special_tokens=False).
        token_strings = self.builder.tokenize_sequence(events, add_bos=True, add_eos=True)
        encoding = self.hf_tokenizer(
            " ".join(token_strings),
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            add_special_tokens=False,
            return_tensors="pt",
        )
        # Keep integer labels integral (multiclass, e.g. CrossEntropyLoss) and cast
        # everything else to float32 (binary/regression, e.g. BCEWithLogitsLoss),
        # matching the class docstring.
        label_dtype = torch.long if np.issubdtype(self.labels.dtype, np.integer) else torch.float32
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "tabular_features": torch.tensor(self.tabular_features[idx], dtype=torch.float32),
            "labels": torch.tensor(self.labels[idx], dtype=label_dtype),
        }
|
|
    def get_stats(self) -> Dict[str, Any]:
        """Return summary statistics of the dataset for logging and sanity checks."""
        return {
            "n_samples": len(self),
            "max_length": self.max_length,
            "n_tabular_features": self.tabular_features.shape[1],
            "label_distribution": {
                "mean": float(self.labels.mean()),
                "std": float(self.labels.std()),
                "min": float(self.labels.min()),
                "max": float(self.labels.max()),
            },
        }
|
|
|
|
def prepare_finetune_dataset(
    user_sequences, tabular_features, labels, builder, hf_tokenizer, max_length=512
) -> DomainFinetuneDataset:
    """Convenience function to create a fine-tuning dataset."""
    ds = DomainFinetuneDataset(user_sequences, tabular_features, labels, builder, hf_tokenizer, max_length)
    # Log from the dataset's converted arrays so this also works when plain lists are passed in.
    logger.info(f"Fine-tune dataset: {len(ds)} samples, max_length={max_length}, "
                f"tabular_features={ds.tabular_features.shape[1]}, label_mean={ds.labels.mean():.3f}")
    return ds
|
|