| """ |
| PLR (Periodic Linear ReLU) Numerical Embeddings. |
| |
| Maps scalar numerical features to high-dimensional dense vectors via |
| learned periodic (sin/cos) activations followed by a linear projection. |
| |
| From: Gorishniy et al. 2022, "On Embeddings for Numerical Features in |
| Tabular Deep Learning" (arXiv:2203.05556, NeurIPS 2022). |
| |
| Used by Nubank nuFormer for the tabular feature branch (291 features). |
| PLR is the ingredient that makes DCNv2 beat LightGBM. |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
|
|
|
|
class PeriodicLinearReLU(nn.Module):
    """Periodic-Linear-ReLU embeddings for scalar numerical features.

    Implements the PLR scheme of Gorishniy et al. 2022: each scalar is
    expanded through learnable sin/cos activations,

        x -> [sin(2*pi*w*x + b), cos(2*pi*w*x + b)] -> Linear -> ReLU

    where both the frequencies ``w`` and phases ``b`` are trainable and
    feature-specific, while the final linear projection is shared across
    features.

    Args:
        n_features: Number of scalar input features.
        n_frequencies: Count of sin/cos pairs learned per feature.
        embedding_dim: Size of each feature's output embedding.

    Shape:
        - Input: ``(batch, n_features)`` raw scalar values.
        - Output: ``(batch, n_features, embedding_dim)``.
    """

    def __init__(self, n_features: int, n_frequencies: int = 64, embedding_dim: int = 64):
        super().__init__()
        self.n_features = n_features
        self.n_frequencies = n_frequencies
        self.embedding_dim = embedding_dim

        # Small-variance init keeps initial periods long (near-linear regime);
        # phases start at zero. NOTE: creation order (frequencies, phases,
        # linear) is kept stable for seeded reproducibility.
        self.frequencies = nn.Parameter(0.01 * torch.randn(n_features, n_frequencies))
        self.phases = nn.Parameter(torch.zeros(n_features, n_frequencies))
        # Shared projection from the 2*K periodic features to the embedding.
        self.linear = nn.Linear(2 * n_frequencies, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, n_features, 1) broadcasts against (n_features, n_frequencies)
        # to give per-feature angles of shape (batch, n_features, n_frequencies).
        theta = 2.0 * math.pi * x[..., None] * self.frequencies + self.phases
        # Sin half first, cos half second -- order matters for the linear layer.
        periodic = torch.cat((theta.sin(), theta.cos()), dim=-1)
        return self.linear(periodic).relu()

    def extra_repr(self) -> str:
        return (
            f"n_features={self.n_features}, n_frequencies={self.n_frequencies}, "
            f"embedding_dim={self.embedding_dim}"
        )
|
|