| """ |
| Joint Fusion Model — nuFormer-style architecture combining: |
| 1. Transaction Transformer (pre-trained DomainTransformer) -> user embedding |
| 2. DCNv2 with PLR embeddings (tabular features) -> feature embedding |
| 3. Shared MLP head -> prediction |
| |
| Architecture follows Nubank nuFormer (arXiv:2507.23267). |
| This is the fine-tuning architecture. Pre-training uses DomainTransformerForCausalLM alone. |
| """ |
|
|
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

| from .plr_embeddings import PeriodicLinearReLU |
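
# Shape contract assumed for PeriodicLinearReLU (inferred from the reshape in
# JointFusionModel.forward below, not verified against plr_embeddings): it maps
# a (batch, n_features) float tensor to (batch, n_features, embedding_dim),
# following the PLR numeric-feature embeddings of Gorishniy et al. (2022).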
|
|
|
|
class DCNv2CrossLayer(nn.Module):
    """Single cross layer from Deep & Cross Network V2 (Wang et al., 2021).

    Computes x_{l+1} = x_0 * (W_l x_l + b_l) + x_l, where * is the
    element-wise product, x_0 is the original input to the cross network,
    and the trailing + x_l is a residual connection.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Linear(dim, dim, bias=True)  # holds both W_l and b_l

    def forward(self, x0: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return x0 * self.weight(x) + x
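
# Minimal shape sketch for the cross layer (sizes below are arbitrary
# assumptions, not values from the nuFormer paper). Stacking l such layers
# lets the cross network model feature interactions of x_0 up to polynomial
# degree l + 1:
#
#     layer = DCNv2CrossLayer(dim=8)
#     x0 = torch.randn(4, 8)
#     x1 = layer(x0, x0)         # x_1 = x_0 * (W_1 x_0 + b_1) + x_0
#     assert x1.shape == (4, 8)  # the feature dimension is preserved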
|
|
|
|
class DCNv2(nn.Module):
    """Deep & Cross Network V2 for tabular feature interaction.

    Stacked variant: the cross network runs first and its output feeds the
    deep MLP.
    """

    def __init__(self, input_dim: int, cross_layers: int = 3, deep_layers: int = 2,
                 deep_dim: int = 256, dropout: float = 0.1):
        super().__init__()
        self.cross_layers = nn.ModuleList([DCNv2CrossLayer(input_dim) for _ in range(cross_layers)])
        layers = []
        in_dim = input_dim
        for _ in range(deep_layers):
            layers.extend([nn.Linear(in_dim, deep_dim), nn.ReLU(), nn.Dropout(dropout)])
            in_dim = deep_dim
        self.deep_network = nn.Sequential(*layers)
        # in_dim equals deep_dim when deep_layers >= 1 and falls back to
        # input_dim when deep_layers == 0 (an empty Sequential is identity).
        self.output_dim = in_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x0 = x  # keep the original input; every cross layer multiplies by it
        cross_out = x
        for cross_layer in self.cross_layers:
            cross_out = cross_layer(x0, cross_out)
        return self.deep_network(cross_out)
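
# Quick usage sketch (sizes are arbitrary assumptions):
#
#     dcn = DCNv2(input_dim=384, cross_layers=3, deep_layers=2, deep_dim=256)
#     out = dcn(torch.randn(4, 384))
#     assert out.shape == (4, dcn.output_dim)  # (4, 256) with these settings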
|
|
|
|
class JointFusionModel(nn.Module):
    """nuFormer-style joint fusion: Transaction Transformer + DCNv2(PLR).

    Architecture:
        Transaction Sequence -> Pre-trained DomainTransformer -> user_embedding
        Tabular Features -> PLR -> flatten -> DCNv2 -> tab_embedding
        Concatenate -> MLP Head -> prediction

    With n_classes == 1 the head is a binary classifier trained with
    BCE-with-logits; with n_classes > 1 it is multi-class with cross-entropy.
    """

    def __init__(self, transformer_model, n_tabular_features: int, n_classes: int = 1,
                 plr_frequencies: int = 64, plr_embedding_dim: int = 64,
                 dcn_cross_layers: int = 3, dcn_deep_layers: int = 2,
                 dcn_deep_dim: int = 256, head_hidden_dim: int = 256, dropout: float = 0.1):
        super().__init__()
        self.transformer = transformer_model
        transformer_dim = transformer_model.config.hidden_size

        # One PLR embedding per numeric feature, flattened into the DCN input.
        self.plr = PeriodicLinearReLU(n_features=n_tabular_features,
                                      n_frequencies=plr_frequencies, embedding_dim=plr_embedding_dim)
        plr_flat_dim = n_tabular_features * plr_embedding_dim
        self.dcn = DCNv2(input_dim=plr_flat_dim, cross_layers=dcn_cross_layers,
                         deep_layers=dcn_deep_layers, deep_dim=dcn_deep_dim, dropout=dropout)

        combined_dim = transformer_dim + self.dcn.output_dim  # robust to deep_layers == 0
        self.head = nn.Sequential(
            nn.Linear(combined_dim, head_hidden_dim), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(head_hidden_dim, n_classes),
        )
        self.n_classes = n_classes
|
|
    def forward(self, input_ids, attention_mask=None, tabular_features=None,
                labels=None) -> Dict[str, Optional[torch.Tensor]]:
        user_embedding = self.transformer.get_user_embedding(input_ids=input_ids,
                                                             attention_mask=attention_mask)
        tab_embedded = self.plr(tabular_features)                   # (B, n_features, emb_dim)
        tab_flat = tab_embedded.reshape(tab_embedded.size(0), -1)   # (B, n_features * emb_dim)
        tab_output = self.dcn(tab_flat)

        combined = torch.cat([user_embedding, tab_output], dim=-1)
        logits = self.head(combined)

        loss = None
        if labels is not None:
            if self.n_classes == 1:
                # Binary head: drop the singleton class dim for BCE with logits.
                loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
            else:
                loss = F.cross_entropy(logits, labels.long())

        return {"loss": loss, "logits": logits}
|
|