"""
Joint Fusion Model — nuFormer-style architecture combining:
  1. Transaction Transformer (pre-trained DomainTransformer) -> user embedding
  2. DCNv2 with PLR embeddings (tabular features) -> feature embedding
  3. Shared MLP head -> prediction

Architecture follows Nubank nuFormer (arXiv:2507.23267).
This is the fine-tuning architecture. Pre-training uses DomainTransformerForCausalLM alone.
"""

from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .plr_embeddings import PeriodicLinearReLU


class DCNv2CrossLayer(nn.Module):
    """Single cross layer from Deep & Cross Network V2 (Wang et al. 2021).
    Computes: x_{l+1} = x_0 * (W_l * x_l + b_l) + x_l
    """
    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim, bias=True)  # W_l and b_l from the formula above

    def forward(self, x0: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # Cross with the original input x0, plus a residual connection.
        return x0 * self.linear(x) + x
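
# Shape sketch (comments only, so nothing runs at import time). Each layer
# preserves the input width, so layers stack freely; stacking L cross layers
# captures feature interactions up to order L + 1:
#
#   layer = DCNv2CrossLayer(dim=16)
#   x0 = torch.randn(8, 16)       # original input, held fixed across layers
#   x1 = layer(x0, x0)            # x0 ⊙ (W x0 + b) + x0  -> (8, 16)
#   x2 = layer(x0, x1)            # second-order cross    -> (8, 16)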


class DCNv2(nn.Module):
    """Deep & Cross Network V2 for tabular feature interaction."""
    def __init__(self, input_dim: int, cross_layers: int = 3, deep_layers: int = 2,
                 deep_dim: int = 256, dropout: float = 0.1):
        super().__init__()
        self.cross_layers = nn.ModuleList([DCNv2CrossLayer(input_dim) for _ in range(cross_layers)])
        layers = []
        in_dim = input_dim
        for _ in range(deep_layers):
            layers.extend([nn.Linear(in_dim, deep_dim), nn.ReLU(), nn.Dropout(dropout)])
            in_dim = deep_dim
        self.deep_network = nn.Sequential(*layers)
        # After the loop, in_dim == deep_dim (or input_dim when deep_layers == 0).
        self.output_dim = in_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x0 is held fixed; each cross layer re-crosses its own output with it.
        x0 = x
        cross_out = x
        for cross_layer in self.cross_layers:
            cross_out = cross_layer(x0, cross_out)
        return self.deep_network(cross_out)
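
# Usage sketch, assuming the defaults above (3 cross layers, 2 deep layers,
# deep_dim=256). A flat (batch, input_dim) vector maps to (batch, output_dim):
#
#   dcn = DCNv2(input_dim=640)           # e.g. 10 PLR features * 64 dims each
#   out = dcn(torch.randn(8, 640))       # -> (8, 256) == (8, dcn.output_dim)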


class JointFusionModel(nn.Module):
    """nuFormer-style joint fusion: Transaction Transformer + DCNv2(PLR).

    Architecture:
        Transaction Sequence -> Pre-trained DomainTransformer -> user_embedding
        Tabular Features -> PLR -> flatten -> DCNv2 -> tab_embedding
        Concatenate -> MLP Head -> prediction
    """
    def __init__(self, transformer_model, n_tabular_features: int, n_classes: int = 1,
                 plr_frequencies: int = 64, plr_embedding_dim: int = 64,
                 dcn_cross_layers: int = 3, dcn_deep_layers: int = 2,
                 dcn_deep_dim: int = 256, head_hidden_dim: int = 256, dropout: float = 0.1):
        super().__init__()
        self.transformer = transformer_model
        transformer_dim = transformer_model.config.hidden_size

        self.plr = PeriodicLinearReLU(n_features=n_tabular_features,
                                       n_frequencies=plr_frequencies, embedding_dim=plr_embedding_dim)
        plr_flat_dim = n_tabular_features * plr_embedding_dim
        self.dcn = DCNv2(input_dim=plr_flat_dim, cross_layers=dcn_cross_layers,
                         deep_layers=dcn_deep_layers, deep_dim=dcn_deep_dim, dropout=dropout)

        combined_dim = transformer_dim + self.dcn.output_dim
        self.head = nn.Sequential(
            nn.Linear(combined_dim, head_hidden_dim), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(head_hidden_dim, n_classes),
        )
        self.n_classes = n_classes

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                tabular_features: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None) -> Dict[str, Optional[torch.Tensor]]:
        # Token sequence -> pooled user embedding: (B, L) -> (B, transformer_dim).
        user_embedding = self.transformer.get_user_embedding(input_ids=input_ids, attention_mask=attention_mask)
        # Tabular path: (B, F) -> PLR (B, F, plr_embedding_dim) -> flatten -> DCNv2 (B, dcn.output_dim).
        tab_embedded = self.plr(tabular_features)
        tab_flat = tab_embedded.reshape(tab_embedded.size(0), -1)
        tab_output = self.dcn(tab_flat)

        # Late fusion: concatenate both embeddings and score with the shared head.
        combined = torch.cat([user_embedding, tab_output], dim=-1)
        logits = self.head(combined)

        loss = None
        if labels is not None:
            if self.n_classes == 1:
                # Binary head: one logit per example, trained with BCE-with-logits.
                loss = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
            else:
                loss = F.cross_entropy(logits, labels.long())

        return {"loss": loss, "logits": logits}