File size: 3,447 Bytes
256963c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
Fine-tuning data pipeline for JointFusionModel.

Prepares datasets that yield {input_ids, attention_mask, tabular_features, labels}
for supervised fine-tuning of the joint Transformer + DCNv2(PLR) fusion model.

Unlike pre-training (which packs sequences for 100% token utilization), fine-tuning
uses per-user sequences padded to a fixed length — each sample represents one user
with one label.
"""

import logging
from typing import Any, Dict, Sequence

import numpy as np
import torch
from torch.utils.data import Dataset as TorchDataset

from ..tokenizers.domain_tokenizer import DomainTokenizerBuilder

logger = logging.getLogger(__name__)


class DomainFinetuneDataset(TorchDataset):
    """Dataset for fine-tuning JointFusionModel on labeled user data.

    Each sample represents one user and is returned as a dict with:
      - input_ids: tokenized transaction sequence (padded/truncated to max_length)
      - attention_mask: 1 for real tokens, 0 for padding
      - tabular_features: float32 feature vector for the user
      - labels: target label tensor with dtype ``label_dtype`` (float32 by
        default, e.g. for binary/regression losses; pass ``torch.long`` for
        multiclass class-index targets)

    Args:
        user_sequences: one event sequence per user; each element is passed
            unchanged to ``builder.tokenize_sequence``.
        tabular_features: array-like with one row per user; stored as float32.
        labels: array-like with one target per user.
        builder: object exposing ``tokenize_sequence(events, add_bos=..., add_eos=...)``
            that returns an iterable of token strings.
        hf_tokenizer: Hugging Face-style tokenizer callable; must have a
            non-None ``pad_token_id``.
        max_length: fixed sequence length after truncation/padding.
        label_dtype: dtype used when converting each label to a tensor
            (keyword-only; defaults to ``torch.float32``).

    Raises:
        ValueError: if the three per-user inputs differ in length, or the
            tokenizer has no pad token.
    """

    def __init__(
        self,
        user_sequences: Sequence[Any],
        tabular_features: Any,
        labels: Any,
        builder: "DomainTokenizerBuilder",
        hf_tokenizer: Any,
        max_length: int = 512,
        *,
        label_dtype: torch.dtype = torch.float32,
    ) -> None:
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if not (len(user_sequences) == len(tabular_features) == len(labels)):
            raise ValueError(
                f"Length mismatch: sequences={len(user_sequences)}, "
                f"tabular={len(tabular_features)}, labels={len(labels)}"
            )
        # padding="max_length" in __getitem__ needs a pad token to fill with.
        if hf_tokenizer.pad_token_id is None:
            raise ValueError("Tokenizer must have pad_token set.")

        self.user_sequences = user_sequences
        self.tabular_features = np.asarray(tabular_features, dtype=np.float32)
        self.labels = np.asarray(labels)
        self.builder = builder
        self.hf_tokenizer = hf_tokenizer
        self.max_length = max_length
        self.label_dtype = label_dtype

    def __len__(self) -> int:
        """Return the number of users (samples)."""
        return len(self.user_sequences)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize user ``idx`` and return its model-ready tensors."""
        events = self.user_sequences[idx]
        token_strings = self.builder.tokenize_sequence(events, add_bos=True, add_eos=True)
        # NOTE(review): joining with spaces assumes hf_tokenizer splits on
        # whitespace back into exactly the builder's tokens — confirm for the
        # tokenizer in use.
        # NOTE(review): with truncation=True, sequences longer than max_length
        # lose their tail (including EOS) unless the tokenizer is configured to
        # truncate from the left — confirm this is the intended window.
        encoding = self.hf_tokenizer(
            " ".join(token_strings),
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            add_special_tokens=False,  # BOS/EOS already added by the builder
            return_tensors="pt",
        )
        return {
            # return_tensors="pt" adds a leading batch dim of size 1; drop it.
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "tabular_features": torch.tensor(self.tabular_features[idx], dtype=torch.float32),
            "labels": torch.tensor(self.labels[idx], dtype=self.label_dtype),
        }

    def get_stats(self) -> Dict[str, Any]:
        """Return summary statistics about the dataset.

        Label statistics are NaN for an empty dataset, because numpy
        reductions such as ``min``/``max`` raise on zero-size arrays.
        """
        feats = self.tabular_features
        if feats.ndim >= 2:
            n_tabular = int(feats.shape[1])
        else:
            # 1-D input: one scalar feature per user (none if the dataset is empty).
            n_tabular = 1 if feats.size else 0

        if self.labels.size:
            label_distribution = {
                "mean": float(self.labels.mean()),
                "std": float(self.labels.std()),
                "min": float(self.labels.min()),
                "max": float(self.labels.max()),
            }
        else:
            nan = float("nan")
            label_distribution = {"mean": nan, "std": nan, "min": nan, "max": nan}

        return {
            "n_samples": len(self),
            "max_length": self.max_length,
            "n_tabular_features": n_tabular,
            "label_distribution": label_distribution,
        }


def prepare_finetune_dataset(
    user_sequences: Sequence[Any],
    tabular_features: Any,
    labels: Any,
    builder: "DomainTokenizerBuilder",
    hf_tokenizer: Any,
    max_length: int = 512,
) -> "DomainFinetuneDataset":
    """Create a DomainFinetuneDataset and log a one-line summary of it.

    Accepts the same array-like inputs as DomainFinetuneDataset (lists, numpy
    arrays, ...). The summary is computed from the dataset's numpy-converted
    copies, so plain Python lists work as well as arrays.

    Returns:
        The constructed DomainFinetuneDataset.
    """
    ds = DomainFinetuneDataset(user_sequences, tabular_features, labels, builder, hf_tokenizer, max_length)

    # Read from the dataset's ndarray copies: the raw arguments may be lists
    # without ``.shape`` / ``.mean()``.
    feats = ds.tabular_features
    if feats.ndim >= 2:
        n_tabular = int(feats.shape[1])
    else:
        n_tabular = 1 if feats.size else 0
    # Mean of an empty array would emit a RuntimeWarning; report NaN explicitly.
    label_mean = float(ds.labels.mean()) if ds.labels.size else float("nan")

    logger.info(
        "Fine-tune dataset: %d samples, max_length=%s, tabular_features=%d, label_mean=%.3f",
        len(ds), max_length, n_tabular, label_mean,
    )
    return ds