"""
Fine-tuning data pipeline for JointFusionModel.

Prepares datasets that yield {input_ids, attention_mask, tabular_features, labels}
for supervised fine-tuning of the joint Transformer + DCNv2(PLR) fusion model.

Unlike pre-training (which packs sequences for 100% token utilization),
fine-tuning uses per-user sequences padded to a fixed length — each sample
represents one user with one label.
"""

import logging
from typing import Any, Dict, Optional, Sequence

import numpy as np
import torch
from torch.utils.data import Dataset as TorchDataset

from ..tokenizers.domain_tokenizer import DomainTokenizerBuilder

logger = logging.getLogger(__name__)


class DomainFinetuneDataset(TorchDataset):
    """Dataset for fine-tuning JointFusionModel on labeled user data.

    Each sample represents one user:
      - input_ids: tokenized transaction sequence (padded/truncated to max_length)
      - attention_mask: 1 for real tokens, 0 for padding
      - tabular_features: float32 numerical feature vector for the user
      - labels: target label. float32 by default (binary / regression targets);
        pass ``label_dtype=torch.long`` for multiclass class indices, which is
        the dtype ``torch.nn.CrossEntropyLoss`` expects for index targets.
    """

    def __init__(
        self,
        user_sequences: Sequence[Any],
        tabular_features: Any,
        labels: Any,
        builder: "DomainTokenizerBuilder",
        hf_tokenizer: Any,
        max_length: int = 512,
        label_dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Store per-user inputs and validate that they line up.

        Args:
            user_sequences: one event sequence per user, handed unchanged to
                ``builder.tokenize_sequence``.
            tabular_features: array-like with one row per user; converted to a
                float32 numpy array.
            labels: array-like with one target per user; converted with
                ``np.asarray``.
            builder: object providing ``tokenize_sequence(events, add_bos=...,
                add_eos=...)`` that returns a list of token strings.
            hf_tokenizer: Hugging Face-style tokenizer; must define
                ``pad_token_id``.
            max_length: fixed sequence length every sample is padded/truncated to.
            label_dtype: torch dtype for the returned ``labels`` tensor.
                ``None`` (default) keeps the historical ``torch.float32``.

        Raises:
            ValueError: if the three per-user inputs differ in length, or the
                tokenizer has no pad token.
        """
        # Explicit exceptions instead of ``assert``: asserts are stripped
        # under ``python -O`` and would silently let misaligned data through.
        if not len(user_sequences) == len(tabular_features) == len(labels):
            raise ValueError(
                f"Length mismatch: sequences={len(user_sequences)}, "
                f"tabular={len(tabular_features)}, labels={len(labels)}"
            )
        # padding="max_length" in __getitem__ cannot work without a pad token.
        if hf_tokenizer.pad_token_id is None:
            raise ValueError("Tokenizer must have pad_token set.")

        self.user_sequences = user_sequences
        self.tabular_features = np.asarray(tabular_features, dtype=np.float32)
        self.labels = np.asarray(labels)
        self.builder = builder
        self.hf_tokenizer = hf_tokenizer
        self.max_length = max_length
        self.label_dtype = torch.float32 if label_dtype is None else label_dtype

    def __len__(self) -> int:
        """Return the number of users (samples)."""
        return len(self.user_sequences)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize user ``idx`` and return one model-ready sample.

        BOS/EOS markers are added by the domain builder, so the HF tokenizer
        is called with ``add_special_tokens=False`` to avoid doubling them.
        NOTE(review): the token strings are re-joined with spaces and
        re-tokenized, which assumes the HF tokenizer splits them back on
        whitespace — confirm against the tokenizer's pre-tokenizer. When the
        sequence exceeds ``max_length``, truncation drops the tail, which can
        include the EOS marker.
        """
        events = self.user_sequences[idx]
        token_strings = self.builder.tokenize_sequence(events, add_bos=True, add_eos=True)
        encoding = self.hf_tokenizer(
            " ".join(token_strings),
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            add_special_tokens=False,
            return_tensors="pt",
        )
        return {
            # return_tensors="pt" adds a leading batch dim of 1; drop it so the
            # DataLoader's collate function can stack samples itself.
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "tabular_features": torch.tensor(self.tabular_features[idx], dtype=torch.float32),
            "labels": torch.tensor(self.labels[idx], dtype=self.label_dtype),
        }

    def get_stats(self) -> Dict[str, Any]:
        """Return a summary of the dataset's size, feature width and labels.

        Label statistics are NaN for an empty dataset, since numpy's
        ``min``/``max`` raise on zero-size arrays. 1-D tabular input is
        reported as a single feature column.
        """
        features = self.tabular_features
        n_features = features.shape[1] if features.ndim > 1 else 1

        if self.labels.size:
            label_distribution = {
                "mean": float(self.labels.mean()),
                "std": float(self.labels.std()),
                "min": float(self.labels.min()),
                "max": float(self.labels.max()),
            }
        else:
            nan = float("nan")
            label_distribution = {"mean": nan, "std": nan, "min": nan, "max": nan}

        return {
            "n_samples": len(self),
            "max_length": self.max_length,
            "n_tabular_features": n_features,
            "label_distribution": label_distribution,
        }


def prepare_finetune_dataset(
    user_sequences,
    tabular_features,
    labels,
    builder,
    hf_tokenizer,
    max_length=512,
    label_dtype=None,
):
    """Create a :class:`DomainFinetuneDataset` and log a one-line summary.

    Arguments are forwarded unchanged to ``DomainFinetuneDataset``;
    ``label_dtype`` defaults to float32 labels, as before.
    """
    ds = DomainFinetuneDataset(
        user_sequences,
        tabular_features,
        labels,
        builder,
        hf_tokenizer,
        max_length,
        label_dtype=label_dtype,
    )
    # Read shape/mean from the dataset's normalized numpy arrays rather than
    # the raw arguments, which may be plain lists without .shape / .mean().
    stats = ds.get_stats()
    logger.info(
        "Fine-tune dataset: %d samples, max_length=%s, tabular_features=%d, label_mean=%.3f",
        stats["n_samples"],
        max_length,
        stats["n_tabular_features"],
        stats["label_distribution"]["mean"],
    )
    return ds