# Upstream commit e881ea3 (rtferraz):
# "Add DCNv2 + JointFusionModel (nuFormer-style Transformer + tabular fusion)"
"""
Joint Fusion Model — nuFormer-style architecture combining:
1. Transaction Transformer (pre-trained DomainTransformer) -> user embedding
2. DCNv2 with PLR embeddings (tabular features) -> feature embedding
3. Shared MLP head -> prediction
Architecture follows Nubank nuFormer (arXiv:2507.23267).
This is the fine-tuning architecture. Pre-training uses DomainTransformerForCausalLM alone.
"""
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .plr_embeddings import PeriodicLinearReLU
class DCNv2CrossLayer(nn.Module):
    """One cross layer of Deep & Cross Network V2 (Wang et al. 2021).

    Implements the residual feature-crossing update
        x_{l+1} = x_0 * (W_l @ x_l + b_l) + x_l
    where `*` is element-wise multiplication with the layer-0 input x_0 and
    the residual term `+ x_l` keeps gradients flowing through deep stacks.
    """

    def __init__(self, dim: int):
        super().__init__()
        # W_l and b_l are packed into a single affine map over the feature dim.
        self.weight = nn.Linear(dim, dim, bias=True)

    def forward(self, x0: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        crossed = x0 * self.weight(x)
        return crossed + x
class DCNv2(nn.Module):
    """Deep & Cross Network V2 for tabular feature interaction.

    Stacked-structure variant: the input passes through `cross_layers`
    explicit feature-cross layers (each conditioned on the original input
    x_0), then through a deep MLP of `deep_layers` hidden layers.

    Args:
        input_dim: Width of the (already embedded/flattened) input vector.
        cross_layers: Number of DCNv2 cross layers.
        deep_layers: Number of hidden layers in the deep MLP (may be 0).
        deep_dim: Hidden width of each deep layer.
        dropout: Dropout probability applied after each deep layer.

    Attributes:
        output_dim: Actual width of `forward`'s output.
    """

    def __init__(self, input_dim: int, cross_layers: int = 3, deep_layers: int = 2,
                 deep_dim: int = 256, dropout: float = 0.1):
        super().__init__()
        self.cross_layers = nn.ModuleList([DCNv2CrossLayer(input_dim) for _ in range(cross_layers)])
        layers = []
        in_dim = input_dim
        for _ in range(deep_layers):
            layers.extend([nn.Linear(in_dim, deep_dim), nn.ReLU(), nn.Dropout(dropout)])
            in_dim = deep_dim
        self.deep_network = nn.Sequential(*layers)
        # Bug fix: with deep_layers == 0 the deep network is an identity, so the
        # true output width is input_dim, not deep_dim. Report the actual width
        # (unchanged — deep_dim — whenever deep_layers >= 1).
        self.output_dim = in_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run cross layers (each re-using the original x as x_0), then the MLP."""
        x0 = x
        cross_out = x
        for cross_layer in self.cross_layers:
            cross_out = cross_layer(x0, cross_out)
        return self.deep_network(cross_out)
class JointFusionModel(nn.Module):
    """nuFormer-style joint fusion of a sequence encoder and tabular features.

    Two branches feed one shared head:
        Transaction sequence -> pre-trained DomainTransformer -> user embedding
        Tabular features -> PLR embeddings -> flatten -> DCNv2 -> tabular embedding
        concat(user, tabular) -> MLP head -> logits

    Returns a dict with "loss" (None when no labels are given) and "logits".
    """

    def __init__(self, transformer_model, n_tabular_features: int, n_classes: int = 1,
                 plr_frequencies: int = 64, plr_embedding_dim: int = 64,
                 dcn_cross_layers: int = 3, dcn_deep_layers: int = 2,
                 dcn_deep_dim: int = 256, head_hidden_dim: int = 256, dropout: float = 0.1):
        super().__init__()
        # Sequence branch: reuse the pre-trained transformer as the user encoder.
        self.transformer = transformer_model
        transformer_dim = transformer_model.config.hidden_size
        # Tabular branch: per-feature PLR embeddings, flattened for DCNv2.
        self.plr = PeriodicLinearReLU(
            n_features=n_tabular_features,
            n_frequencies=plr_frequencies,
            embedding_dim=plr_embedding_dim,
        )
        self.dcn = DCNv2(
            input_dim=n_tabular_features * plr_embedding_dim,
            cross_layers=dcn_cross_layers,
            deep_layers=dcn_deep_layers,
            deep_dim=dcn_deep_dim,
            dropout=dropout,
        )
        # Shared prediction head over the concatenated branch outputs.
        self.head = nn.Sequential(
            nn.Linear(transformer_dim + dcn_deep_dim, head_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden_dim, n_classes),
        )
        self.n_classes = n_classes

    def forward(self, input_ids, attention_mask=None, tabular_features=None, labels=None):
        """Fuse both branches and (optionally) compute the task loss."""
        seq_repr = self.transformer.get_user_embedding(
            input_ids=input_ids, attention_mask=attention_mask
        )
        plr_out = self.plr(tabular_features)
        dcn_out = self.dcn(plr_out.reshape(plr_out.size(0), -1))
        logits = self.head(torch.cat([seq_repr, dcn_out], dim=-1))
        loss = None
        if labels is not None:
            if self.n_classes == 1:
                # Binary task: single logit per sample, BCE with logits.
                loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
            else:
                loss = F.cross_entropy(logits, labels.long())
        return {"loss": loss, "logits": logits}