rtferraz committed on
Commit e881ea3 · verified · 1 Parent(s): d685c0e

Add DCNv2 + JointFusionModel (nuFormer-style Transformer + tabular fusion)

src/domain_tokenizer/models/joint_fusion.py ADDED
@@ -0,0 +1,99 @@
+ """
+ Joint Fusion Model — nuFormer-style architecture combining:
+ 1. Transaction Transformer (pre-trained DomainTransformer) -> user embedding
+ 2. DCNv2 with PLR embeddings (tabular features) -> feature embedding
+ 3. Shared MLP head -> prediction
+
+ Architecture follows Nubank nuFormer (arXiv:2507.23267).
+ This is the fine-tuning architecture. Pre-training uses DomainTransformerForCausalLM alone.
+ """
+
+ from typing import Dict, Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .plr_embeddings import PeriodicLinearReLU
+
+
+ class DCNv2CrossLayer(nn.Module):
+     """Single cross layer from Deep & Cross Network V2 (Wang et al. 2021).
+     Computes: x_{l+1} = x_0 * (W_l x_l + b_l) + x_l, where the product with x_0 is element-wise.
+     """
+     def __init__(self, dim: int):
+         super().__init__()
+         self.weight = nn.Linear(dim, dim, bias=True)
+
+     def forward(self, x0: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+         return x0 * self.weight(x) + x
+
+
+ class DCNv2(nn.Module):
+     """Deep & Cross Network V2 for tabular feature interaction (stacked: cross layers, then deep MLP)."""
+     def __init__(self, input_dim: int, cross_layers: int = 3, deep_layers: int = 2,
+                  deep_dim: int = 256, dropout: float = 0.1):
+         super().__init__()
+         self.cross_layers = nn.ModuleList([DCNv2CrossLayer(input_dim) for _ in range(cross_layers)])
+         layers = []
+         in_dim = input_dim
+         for _ in range(deep_layers):
+             layers.extend([nn.Linear(in_dim, deep_dim), nn.ReLU(), nn.Dropout(dropout)])
+             in_dim = deep_dim
+         self.deep_network = nn.Sequential(*layers)
+         self.output_dim = deep_dim
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x0 = x
+         cross_out = x
+         for cross_layer in self.cross_layers:
+             cross_out = cross_layer(x0, cross_out)  # interact running state with the original input x0
+         return self.deep_network(cross_out)
+
+
+ class JointFusionModel(nn.Module):
+     """nuFormer-style joint fusion: Transaction Transformer + DCNv2(PLR).
+
+     Architecture:
+         Transaction Sequence -> Pre-trained DomainTransformer -> user_embedding
+         Tabular Features -> PLR -> flatten -> DCNv2 -> tab_embedding
+         Concatenate -> MLP Head -> prediction
+     """
+     def __init__(self, transformer_model, n_tabular_features: int, n_classes: int = 1,
+                  plr_frequencies: int = 64, plr_embedding_dim: int = 64,
+                  dcn_cross_layers: int = 3, dcn_deep_layers: int = 2,
+                  dcn_deep_dim: int = 256, head_hidden_dim: int = 256, dropout: float = 0.1):
+         super().__init__()
+         self.transformer = transformer_model
+         transformer_dim = transformer_model.config.hidden_size
+
+         self.plr = PeriodicLinearReLU(n_features=n_tabular_features,  # tabular branch: per-feature PLR embeddings
+                                       n_frequencies=plr_frequencies, embedding_dim=plr_embedding_dim)
+         plr_flat_dim = n_tabular_features * plr_embedding_dim
+         self.dcn = DCNv2(input_dim=plr_flat_dim, cross_layers=dcn_cross_layers,
+                          deep_layers=dcn_deep_layers, deep_dim=dcn_deep_dim, dropout=dropout)
+
+         combined_dim = transformer_dim + dcn_deep_dim  # concatenated user + tabular embedding width
+         self.head = nn.Sequential(
+             nn.Linear(combined_dim, head_hidden_dim), nn.ReLU(),
+             nn.Dropout(dropout), nn.Linear(head_hidden_dim, n_classes),
+         )
+         self.n_classes = n_classes
+
+     def forward(self, input_ids, attention_mask=None, tabular_features=None, labels=None):
+         user_embedding = self.transformer.get_user_embedding(input_ids=input_ids, attention_mask=attention_mask)
+         tab_embedded = self.plr(tabular_features)
+         tab_flat = tab_embedded.reshape(tab_embedded.size(0), -1)  # one flat vector per sample
+         tab_output = self.dcn(tab_flat)
+
+         combined = torch.cat([user_embedding, tab_output], dim=-1)
+         logits = self.head(combined)
+
+         loss = None
+         if labels is not None:
+             if self.n_classes == 1:  # binary target: single logit, BCE-with-logits
+                 loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
+             else:
+                 loss = F.cross_entropy(logits, labels.long())
+
+         return {"loss": loss, "logits": logits}
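For reference, a minimal usage sketch of the new JointFusionModel (an editorial illustration, not part of the committed file). It assumes the package is importable as domain_tokenizer (i.e. src/ is on the Python path), that the pre-trained transformer exposes config.hidden_size and a get_user_embedding(input_ids, attention_mask) method returning one embedding per sequence, and that PeriodicLinearReLU maps (batch, n_features) to (batch, n_features, embedding_dim). DummyTransformer below is a hypothetical stand-in for the pre-trained DomainTransformer.

from types import SimpleNamespace

import torch
import torch.nn as nn

from domain_tokenizer.models.joint_fusion import JointFusionModel  # assumes src/ is on sys.path


class DummyTransformer(nn.Module):
    """Hypothetical stand-in for the pre-trained DomainTransformer (illustration only)."""

    def __init__(self, vocab_size: int = 1000, hidden_size: int = 128):
        super().__init__()
        self.config = SimpleNamespace(hidden_size=hidden_size)
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def get_user_embedding(self, input_ids, attention_mask=None):
        # Mean-pool token embeddings into a single user embedding per sequence.
        return self.embed(input_ids).mean(dim=1)


model = JointFusionModel(
    transformer_model=DummyTransformer(),
    n_tabular_features=12,  # e.g. 12 numeric tabular features
    n_classes=1,            # single logit -> BCE-with-logits loss
)

batch = {
    "input_ids": torch.randint(0, 1000, (4, 32)),           # (batch, seq_len) transaction tokens
    "attention_mask": torch.ones(4, 32, dtype=torch.long),
    "tabular_features": torch.randn(4, 12),                 # (batch, n_tabular_features)
    "labels": torch.randint(0, 2, (4,)),
}

out = model(**batch)
print(out["logits"].shape)  # torch.Size([4, 1])
print(out["loss"])          # scalar BCE-with-logits loss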