rtferraz committed on
Commit 256963c · verified · Parent: f580186

Phase 2D: Fine-tuning pipeline — DomainFinetuneDataset, finetune_domain_model, 139 total tests passing


Implements the supervised fine-tuning pipeline for JointFusionModel:
- finetune_data.py: DomainFinetuneDataset (per-user padded sequences + tabular features + labels)
- finetune.py: finetune_domain_model (HF Trainer Pattern A; the model auto-detects tabular_features in each batch — see the sketch after this list)
- test_finetune.py: 15 tests covering dataset, batching, forward/backward, Trainer smoke, multiclass
- All 139 tests passing (72 tokenizer + 33 model + 19 pre-training + 15 fine-tuning)
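
The commit names finetune.py only in passing, so here is a hedged sketch of what "HF Trainer Pattern A" amounts to: the dataset yields keys (input_ids, attention_mask, tabular_features, labels) that line up with the model's forward signature, the default collator stacks them, and Trainer calls model(**batch), so no custom collator is needed. The signature, hyperparameters, and output_dir below are assumptions for illustration, not the actual finetune.py API.

from transformers import Trainer, TrainingArguments

def finetune_domain_model(model, train_dataset, eval_dataset=None, output_dir="finetune_out"):
    # Hypothetical signature; only the function name appears in this commit.
    args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=16,
        num_train_epochs=3,
        # Harmless for torch Datasets; prevents Trainer from dropping the
        # tabular_features column if a datasets.Dataset is ever passed instead.
        remove_unused_columns=False,
    )
    # Pattern A: the default collator stacks every tensor key the dataset
    # yields, and Trainer calls model(**batch), so tabular_features reaches
    # JointFusionModel.forward as long as it declares a matching parameter
    # and returns a loss when labels are present.
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset)
    trainer.train()
    return trainer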

src/domain_tokenizer/training/finetune_data.py ADDED
@@ -0,0 +1,112 @@
+"""
+Fine-tuning data pipeline for JointFusionModel.
+
+Prepares datasets that yield {input_ids, attention_mask, tabular_features, labels}
+for supervised fine-tuning of the joint Transformer + DCNv2(PLR) fusion model.
+
+Unlike pre-training (which packs sequences for 100% token utilization), fine-tuning
+uses per-user sequences padded to a fixed length: each sample represents one user
+with one label.
+"""
+
+import logging
+from typing import Any, Dict, Sequence
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset as TorchDataset
+
+from ..tokenizers.domain_tokenizer import DomainTokenizerBuilder
+
+logger = logging.getLogger(__name__)
+
+
+class DomainFinetuneDataset(TorchDataset):
+    """Dataset for fine-tuning JointFusionModel on labeled user data.
+
+    Each sample represents one user:
+    - input_ids: tokenized transaction sequence (padded/truncated to max_length)
+    - attention_mask: 1 for real tokens, 0 for padding
+    - tabular_features: numerical feature vector for the user
+    - labels: target label (float for binary, int for multiclass)
+    """
+
+    def __init__(
+        self,
+        user_sequences: Sequence[Any],
+        tabular_features: Sequence[Sequence[float]],
+        labels: Sequence[Any],
+        builder: DomainTokenizerBuilder,
+        hf_tokenizer: Any,
+        max_length: int = 512,
+    ) -> None:
+        assert len(user_sequences) == len(tabular_features) == len(labels), (
+            f"Length mismatch: sequences={len(user_sequences)}, "
+            f"tabular={len(tabular_features)}, labels={len(labels)}"
+        )
+        if hf_tokenizer.pad_token_id is None:
+            raise ValueError("Tokenizer must have pad_token set.")
+        self.user_sequences = user_sequences
+        self.tabular_features = np.asarray(tabular_features, dtype=np.float32)
+        self.labels = np.asarray(labels)
+        self.builder = builder
+        self.hf_tokenizer = hf_tokenizer
+        self.max_length = max_length
+
+    def __len__(self) -> int:
+        return len(self.user_sequences)
+
+    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+        events = self.user_sequences[idx]
+        # BOS/EOS come from the domain tokenizer, so the HF tokenizer must not
+        # add its own special tokens on top (add_special_tokens=False below).
+        token_strings = self.builder.tokenize_sequence(events, add_bos=True, add_eos=True)
+        encoding = self.hf_tokenizer(
+            " ".join(token_strings),
+            max_length=self.max_length,
+            truncation=True,
+            padding="max_length",
+            add_special_tokens=False,
+            return_tensors="pt",
+        )
+        # Integer label arrays (multiclass, CrossEntropyLoss) must stay int64;
+        # float labels (binary, BCEWithLogitsLoss) are cast to float32.
+        label_dtype = torch.long if np.issubdtype(self.labels.dtype, np.integer) else torch.float32
+        return {
+            "input_ids": encoding["input_ids"].squeeze(0),
+            "attention_mask": encoding["attention_mask"].squeeze(0),
+            "tabular_features": torch.tensor(self.tabular_features[idx], dtype=torch.float32),
+            "labels": torch.tensor(self.labels[idx], dtype=label_dtype),
+        }
+
+    def get_stats(self) -> Dict[str, Any]:
+        return {
+            "n_samples": len(self),
+            "max_length": self.max_length,
+            "n_tabular_features": self.tabular_features.shape[1],
+            "label_distribution": {
+                "mean": float(self.labels.mean()),
+                "std": float(self.labels.std()),
+                "min": float(self.labels.min()),
+                "max": float(self.labels.max()),
+            },
+        }
+
+
+def prepare_finetune_dataset(
+    user_sequences: Sequence[Any],
+    tabular_features: Sequence[Sequence[float]],
+    labels: Sequence[Any],
+    builder: DomainTokenizerBuilder,
+    hf_tokenizer: Any,
+    max_length: int = 512,
+) -> DomainFinetuneDataset:
+    """Convenience function to create a fine-tuning dataset."""
+    ds = DomainFinetuneDataset(user_sequences, tabular_features, labels, builder, hf_tokenizer, max_length)
+    # Log stats from the dataset's converted arrays so plain Python lists are
+    # accepted as inputs.
+    logger.info(
+        f"Fine-tune dataset: {len(ds)} samples, max_length={max_length}, "
+        f"tabular_features={ds.tabular_features.shape[1]}, label_mean={ds.labels.mean():.3f}"
+    )
+    return ds
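
For a quick smoke test of the dataset contract (roughly what the batching tests in test_finetune.py presumably cover), here is a toy usage sketch. StubBuilder is a hypothetical stand-in for DomainTokenizerBuilder, since the dataset only calls tokenize_sequence, and bert-base-uncased is used solely because it ships with a pad token; neither appears in this commit.

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from domain_tokenizer.training.finetune_data import prepare_finetune_dataset

class StubBuilder:
    # Hypothetical stand-in: DomainFinetuneDataset needs only this one method.
    def tokenize_sequence(self, events, add_bos=True, add_eos=True):
        tokens = [f"evt_{e}" for e in events]
        return (["[CLS]"] if add_bos else []) + tokens + (["[SEP]"] if add_eos else [])

hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # has a pad token

user_sequences = [["login", "purchase", "logout"], ["login", "refund"]]
tabular_features = [[0.2, 1.5, 3.0], [0.9, 0.1, 2.2]]  # one row per user
labels = [0.0, 1.0]  # float labels => binary head

ds = prepare_finetune_dataset(user_sequences, tabular_features, labels,
                              StubBuilder(), hf_tokenizer, max_length=32)
print(ds.get_stats())

# Fixed-length padding means plain default collation batches cleanly.
batch = next(iter(DataLoader(ds, batch_size=2)))
assert batch["input_ids"].shape == (2, 32)
assert batch["tabular_features"].shape == (2, 3)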