"""
Joint Fusion Model — nuFormer-style architecture combining:

1. Transaction Transformer (pre-trained DomainTransformer) -> user embedding
2. DCNv2 with PLR embeddings (tabular features) -> feature embedding
3. Shared MLP head -> prediction

Architecture follows Nubank nuFormer (arXiv:2507.23267).
This is the fine-tuning architecture. Pre-training uses
DomainTransformerForCausalLM alone.
"""

from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .plr_embeddings import PeriodicLinearReLU


class DCNv2CrossLayer(nn.Module):
    """Single cross layer from Deep & Cross Network V2 (Wang et al. 2021).

    Computes ``x_{l+1} = x_0 ⊙ (W_l x_l + b_l) + x_l``, where ``⊙`` is the
    element-wise product and ``W_l x_l + b_l`` is a full (dim x dim) affine map.

    Args:
        dim: Width of the input vectors; the output has the same width.
    """

    def __init__(self, dim: int):
        super().__init__()
        # NOTE: ``weight`` is a whole nn.Linear (matrix + bias), not a raw
        # parameter; renaming this attribute would change state_dict keys.
        self.weight = nn.Linear(dim, dim, bias=True)

    def forward(self, x0: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Apply one cross step.

        Args:
            x0: The original input of the cross stack (layer 0).
            x: Output of the previous cross layer (``x0`` for the first layer).

        Returns:
            Tensor with the same shape as ``x``.
        """
        return x0 * self.weight(x) + x


class DCNv2(nn.Module):
    """Deep & Cross Network V2 for tabular feature interaction.

    Stacked layout: the input first goes through ``cross_layers`` cross layers,
    then through ``deep_layers`` groups of Linear -> ReLU -> Dropout.

    Args:
        input_dim: Width of the flat input vector.
        cross_layers: Number of :class:`DCNv2CrossLayer` steps (0 disables them).
        deep_layers: Number of Linear/ReLU/Dropout groups. With 0 the deep part
            is an identity and the cross output is returned unchanged.
        deep_dim: Hidden width of every deep layer.
        dropout: Dropout probability applied after each deep layer.

    Attributes:
        output_dim: Width of the tensor returned by :meth:`forward`
            (``deep_dim`` if ``deep_layers > 0``, otherwise ``input_dim``).
    """

    def __init__(self, input_dim: int, cross_layers: int = 3, deep_layers: int = 2,
                 deep_dim: int = 256, dropout: float = 0.1):
        super().__init__()
        self.cross_layers = nn.ModuleList(
            [DCNv2CrossLayer(input_dim) for _ in range(cross_layers)]
        )

        layers = []
        in_dim = input_dim
        for _ in range(deep_layers):
            layers.extend([nn.Linear(in_dim, deep_dim), nn.ReLU(), nn.Dropout(dropout)])
            in_dim = deep_dim
        self.deep_network = nn.Sequential(*layers)
        # An empty nn.Sequential is an identity, so when deep_layers == 0 the
        # output keeps the input width; report the width actually produced.
        self.output_dim = in_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the cross stack followed by the deep MLP.

        Args:
            x: Flat input whose last dimension is ``input_dim``.

        Returns:
            Tensor whose last dimension is ``self.output_dim``.
        """
        x0 = x
        cross_out = x
        for cross_layer in self.cross_layers:
            cross_out = cross_layer(x0, cross_out)
        return self.deep_network(cross_out)


class JointFusionModel(nn.Module):
    """nuFormer-style joint fusion: Transaction Transformer + DCNv2(PLR).

    Architecture:
        Transaction Sequence -> Pre-trained DomainTransformer -> user_embedding
        Tabular Features -> PLR -> flatten -> DCNv2 -> tab_embedding
        Concatenate -> MLP Head -> prediction

    Args:
        transformer_model: Pre-trained transformer. It must expose
            ``config.hidden_size`` and ``get_user_embedding(input_ids=...,
            attention_mask=...)``. The user embedding width is assumed to equal
            ``config.hidden_size`` (TODO confirm against DomainTransformer).
        n_tabular_features: Number of tabular inputs given to the PLR embedding.
        n_classes: 1 selects binary BCE-with-logits loss; any other value is the
            number of classes for cross-entropy loss.
        plr_frequencies: Passed to ``PeriodicLinearReLU`` as ``n_frequencies``.
        plr_embedding_dim: Passed to ``PeriodicLinearReLU`` as ``embedding_dim``.
            The flattened PLR output is assumed to be
            ``n_tabular_features * plr_embedding_dim`` wide.
        dcn_cross_layers: Number of DCNv2 cross layers.
        dcn_deep_layers: Number of DCNv2 deep layers (0 is allowed).
        dcn_deep_dim: Hidden width of the DCNv2 deep layers.
        head_hidden_dim: Hidden width of the prediction head.
        dropout: Dropout probability used in DCNv2 and in the head.
    """

    def __init__(self, transformer_model, n_tabular_features: int, n_classes: int = 1,
                 plr_frequencies: int = 64, plr_embedding_dim: int = 64,
                 dcn_cross_layers: int = 3, dcn_deep_layers: int = 2,
                 dcn_deep_dim: int = 256, head_hidden_dim: int = 256,
                 dropout: float = 0.1):
        super().__init__()
        self.transformer = transformer_model
        transformer_dim = transformer_model.config.hidden_size

        self.plr = PeriodicLinearReLU(
            n_features=n_tabular_features,
            n_frequencies=plr_frequencies,
            embedding_dim=plr_embedding_dim,
        )
        plr_flat_dim = n_tabular_features * plr_embedding_dim

        self.dcn = DCNv2(
            input_dim=plr_flat_dim,
            cross_layers=dcn_cross_layers,
            deep_layers=dcn_deep_layers,
            deep_dim=dcn_deep_dim,
            dropout=dropout,
        )

        # Size the head from the DCN's real output width, which differs from
        # dcn_deep_dim when dcn_deep_layers == 0.
        combined_dim = transformer_dim + self.dcn.output_dim
        self.head = nn.Sequential(
            nn.Linear(combined_dim, head_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden_dim, n_classes),
        )
        self.n_classes = n_classes

    def forward(self, input_ids, attention_mask=None, tabular_features=None,
                labels=None) -> Dict[str, Optional[torch.Tensor]]:
        """Encode both inputs, fuse them, and optionally compute the loss.

        Args:
            input_ids: Transaction token ids, forwarded to
                ``transformer.get_user_embedding``.
            attention_mask: Optional mask forwarded together with ``input_ids``.
            tabular_features: Input to the PLR embedding; presumably shaped
                (batch, n_tabular_features) — TODO confirm. Required.
            labels: Optional targets. For ``n_classes == 1`` they may be shaped
                (batch,) or (batch, 1). Otherwise they are class indices for
                ``F.cross_entropy``.

        Returns:
            Dict with ``"logits"`` (last dimension ``n_classes``) and ``"loss"``
            (``None`` when ``labels`` is not given).

        Raises:
            ValueError: If ``tabular_features`` is ``None``.
        """
        if tabular_features is None:
            raise ValueError("tabular_features is required for JointFusionModel.forward")

        user_embedding = self.transformer.get_user_embedding(
            input_ids=input_ids, attention_mask=attention_mask
        )

        tab_embedded = self.plr(tabular_features)
        # Flatten everything after the batch dimension into one feature vector.
        tab_flat = tab_embedded.reshape(tab_embedded.size(0), -1)
        tab_output = self.dcn(tab_flat)

        combined = torch.cat([user_embedding, tab_output], dim=-1)
        logits = self.head(combined)

        loss = None if labels is None else self._compute_loss(logits, labels)
        return {"loss": loss, "logits": logits}

    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return BCE-with-logits loss when ``n_classes == 1``, else cross-entropy."""
        if self.n_classes == 1:
            binary_logits = logits.squeeze(-1)
            # Labels of shape (B,) or (B, 1) are both reshaped to match the
            # squeezed logits; otherwise BCE rejects the size mismatch.
            targets = labels.reshape(binary_logits.shape).float()
            return F.binary_cross_entropy_with_logits(binary_logits, targets)
        return F.cross_entropy(logits, labels.long())