import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
class HopfieldReLU(nn.Module):
    """
    Hopfield ReLU block.

    Forward (gradient of the energy w.r.t. x):
        -ReLU(x W^T + b) @ W

    Energy:
        E(x) = -0.5 * sum ReLU(x W^T + b)^2
    """

    def __init__(
        self,
        embedding_dim,
        nmemories,
        bias=True,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()
        self.initializer_range = float(initializer_range)

        self.W = nn.Parameter(torch.empty(nmemories, embedding_dim, device=device))
        self.dropout = nn.Dropout(dropout)

        if bias:
            self.bias = nn.Parameter(torch.empty(nmemories, device=device))
        else:
            self.bias = None

        with torch.no_grad():
            self.W.normal_(mean=0.0, std=self.initializer_range)
            if self.bias is not None:
                self.bias.zero_()

    def forward(self, x):
        """
        x: (B, S, D) or (S, D)
        returns: same shape as x
        """
        squeeze_back = False
        if x.dim() == 2:
            x = x.unsqueeze(0)
            squeeze_back = True

        # Memory activations: H = x W^T + b, shape (B, S, nmemories).
        H = x @ self.W.t()
        if self.bias is not None:
            H = H + self.bias.view(1, 1, -1)

        # Gradient of the energy: -ReLU(H) @ W.
        out = -(torch.relu(H) @ self.W)
        out = self.dropout(out)

        if squeeze_back:
            out = out.squeeze(0)
        return out

    def energy(self, x):
        """
        Scalar energy summed over batch and tokens.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)
        elif x.dim() == 2:
            x = x.unsqueeze(0)

        H = x @ self.W.t()
        if self.bias is not None:
            H = H + self.bias.view(1, 1, -1)

        return -0.5 * (torch.relu(H) ** 2).sum()


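# A minimal usage sketch (illustrative only; the helper name and the sizes
# below are arbitrary, not part of the module): the defining property of the
# block is that `forward` returns exactly the gradient of `energy` with
# respect to the input, which `torch.func.grad` can confirm.
def _example_hopfield_relu():
    block = HopfieldReLU(embedding_dim=16, nmemories=32, dropout=0.0).eval()
    x = torch.randn(2, 5, 16)
    with torch.no_grad():
        update = block(x)  # (2, 5, 16)
    grad = torch.func.grad(lambda z: block.energy(z))(x)
    assert torch.allclose(update, grad, atol=1e-5)

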
class HopfieldSoftmax(nn.Module):
    """
    Hopfield softmax block (energy-based CHNSoftmax analogue).

    Energy:
        E(x) = -(1 / beta) * sum logsumexp(beta * (x W^T + b))

    Forward (gradient of the energy w.r.t. x):
        -softmax(beta * (x W^T + b)) @ W
    """

    def __init__(
        self,
        embedding_dim,
        nmemories,
        beta=1.0,
        bias=True,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()
        self.beta = float(beta)
        self.initializer_range = float(initializer_range)

        self.W = nn.Parameter(torch.empty(nmemories, embedding_dim, device=device))
        self.dropout = nn.Dropout(dropout)

        if bias:
            self.bias = nn.Parameter(torch.empty(nmemories, device=device))
        else:
            self.bias = None

        with torch.no_grad():
            self.W.normal_(mean=0.0, std=self.initializer_range)
            if self.bias is not None:
                self.bias.zero_()

    def forward(self, x):
        """
        x: (D,) or (S, D) or (B, S, D)
        returns: same shape as x
        """
        squeeze_1d = False
        squeeze_2d = False

        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)
            squeeze_1d = True
        elif x.dim() == 2:
            x = x.unsqueeze(0)
            squeeze_2d = True

        # Memory logits: beta * (x W^T + b), shape (B, S, nmemories).
        logits = self.beta * (x @ self.W.t())
        if self.bias is not None:
            logits = logits + (self.beta * self.bias).view(1, 1, -1)

        # Gradient of the energy: -softmax(logits) @ W.
        A = torch.softmax(logits, dim=-1)
        out = -(A @ self.W)
        out = self.dropout(out)

        if squeeze_1d:
            return out.squeeze(0).squeeze(0)
        if squeeze_2d:
            return out.squeeze(0)
        return out

    def energy(self, x):
        """
        Scalar energy summed over batch and tokens.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)
        elif x.dim() == 2:
            x = x.unsqueeze(0)

        logits = self.beta * (x @ self.W.t())
        if self.bias is not None:
            logits = logits + (self.beta * self.bias).view(1, 1, -1)

        return -(1.0 / self.beta) * torch.logsumexp(logits, dim=-1).sum()


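# Sketch of the identity the softmax block relies on (illustrative; the
# tensor sizes and helper name are arbitrary): the gradient of
# (1 / beta) * logsumexp(beta * z) with respect to z is softmax(beta * z),
# so differentiating the energy through x W^T + b yields the
# -softmax(...) @ W update returned by `forward`.
def _example_softmax_energy_identity():
    beta = 2.0
    z = torch.randn(7)

    def scaled_lse(t):
        return (1.0 / beta) * torch.logsumexp(beta * t, dim=-1)

    grad_lse = torch.func.grad(scaled_lse)(z)
    assert torch.allclose(grad_lse, torch.softmax(beta * z, dim=-1), atol=1e-6)

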
class HopfieldMHA(nn.Module):
    """
    Multi-head attention update corresponding to the gradient of the ET
    (Energy Transformer) attention energy.

    Conventions:
    - input x: (B, S, D) or (S, D)
    - attention_mask: (B, S), True/1 = keep, False/0 = pad
    """

    def __init__(
        self,
        embedding_dim,
        nheads,
        beta=None,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()

        if embedding_dim % nheads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by nheads ({nheads})."
            )

        self.nheads = nheads
        self.head_dim = embedding_dim // nheads

        # Default inverse temperature: the usual 1 / sqrt(head_dim) attention scale.
        if beta is None:
            self.beta = 1.0 / (self.head_dim ** 0.5)
        else:
            self.beta = float(beta)

        self.initializer_range = float(initializer_range)
        self.dropout = nn.Dropout(dropout)

        self.Wq = nn.Parameter(
            torch.empty(nheads, embedding_dim, self.head_dim, device=device)
        )
        self.Wk = nn.Parameter(
            torch.empty(nheads, embedding_dim, self.head_dim, device=device)
        )

        with torch.no_grad():
            self.Wq.normal_(mean=0.0, std=self.initializer_range)
            self.Wk.normal_(mean=0.0, std=self.initializer_range)

    def _ensure_batch_dim(self, x):
        squeeze_back = False
        if x.dim() == 2:
            x = x.unsqueeze(0)
            squeeze_back = True
        return x, squeeze_back

    def _resolve_keep_mask(self, attention_mask=None):
        if attention_mask is not None:
            return attention_mask.to(torch.bool)
        return None

    def _no_self_mask(self, seq_len, device):
        return torch.eye(seq_len, dtype=torch.bool, device=device)

    def forward(
        self,
        x,
        attention_mask=None,
        allow_self=True,
    ):
        """
        Returns:
            tensor of shape (B, S, D) or (S, D)
        """
        x, squeeze_back = self._ensure_batch_dim(x)
        B, S, _ = x.shape

        keep = self._resolve_keep_mask(attention_mask)

        # Per-head projections: (B, H, S, head_dim).
        Q = x.unsqueeze(1) @ self.Wq
        K = x.unsqueeze(1) @ self.Wk

        WqT = self.Wq.transpose(-1, -2)
        WkT = self.Wk.transpose(-1, -2)

        # The energy gradient has two terms because both Q and K depend on x:
        # the Q-path uses V1 = K Wq^T, the K-path uses V2 = Q Wk^T.
        V1 = K @ WqT
        V2 = Q @ WkT

        # SDPA divides the logits by sqrt(d_k) internally, so pre-scale Q by
        # beta * sqrt(d_k) to obtain logits of exactly beta * <q, k>.
        dk = Q.shape[-1]
        q_scale = self.beta * (dk ** 0.5)
        Qs = Q * q_scale

        sdpa_mask = None
        if keep is not None:
            sdpa_mask = keep.view(B, 1, 1, S)

        if not allow_self:
            no_self = ~self._no_self_mask(S, x.device)
            no_self = no_self.view(1, 1, S, S)
            sdpa_mask = no_self if sdpa_mask is None else (sdpa_mask & no_self)

        # First term: -(softmax(beta Q K^T) @ K) Wq^T, computed with SDPA.
        # Only the listed backends are enabled; the dispatcher picks one that
        # supports the provided boolean mask.
        with sdpa_kernel(
            [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
        ):
            T1 = -F.scaled_dot_product_attention(
                Qs,
                K,
                V1,
                attn_mask=sdpa_mask,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=False,
            )

        # Second term: the same attention matrix, transposed, applied to Q Wk^T.
        # Note: with dropout > 0 this mask is drawn independently of the one
        # used inside SDPA above.
        logits = self.beta * (Q @ K.transpose(-2, -1))
        neg_inf = -float("inf")

        if keep is not None:
            logits = logits.masked_fill((~keep).view(B, 1, 1, S), neg_inf)

        if not allow_self:
            logits = logits.masked_fill(
                self._no_self_mask(S, x.device).view(1, 1, S, S), neg_inf
            )

        A = torch.softmax(logits.float(), dim=-1).to(logits.dtype)
        A = self.dropout(A)

        T2 = -(A.transpose(-2, -1) @ V2)
        out = (T1 + T2).sum(dim=1)

        if squeeze_back:
            out = out.squeeze(0)
        return out

    def energy(
        self,
        x,
        attention_mask=None,
        allow_self=True,
    ):
        """
        Scalar energy summed over batch, heads and tokens.

        E = -(1 / beta) * sum logsumexp(beta * <Q, K> + mask)
        """
        x, _ = self._ensure_batch_dim(x)
        B, S, _ = x.shape

        keep = self._resolve_keep_mask(attention_mask)

        Q = x.unsqueeze(1) @ self.Wq
        K = x.unsqueeze(1) @ self.Wk

        logits = self.beta * (Q @ K.transpose(-2, -1))
        neg_inf = -float("inf")

        if keep is not None:
            logits = logits.masked_fill((~keep).view(B, 1, 1, S), neg_inf)

        if not allow_self:
            logits = logits.masked_fill(
                self._no_self_mask(S, x.device).view(1, 1, S, S), neg_inf
            )

        lse = torch.logsumexp(logits, dim=-1)
        return -(1.0 / self.beta) * lse.sum()


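# Sketch of the Q pre-scaling used in HopfieldMHA.forward (illustrative; the
# shapes and helper name are arbitrary): SDPA divides the logits by sqrt(d_k)
# internally, so passing Q * beta * sqrt(d_k) reproduces an explicit softmax
# over beta * Q K^T.
def _example_sdpa_scaling():
    beta, dk = 0.7, 4
    Q = torch.randn(1, 2, 5, dk)
    K = torch.randn(1, 2, 5, dk)
    V = torch.randn(1, 2, 5, dk)
    sdpa_out = F.scaled_dot_product_attention(Q * beta * dk ** 0.5, K, V)
    A = torch.softmax(beta * (Q @ K.transpose(-2, -1)), dim=-1)
    assert torch.allclose(sdpa_out, A @ V, atol=1e-5)

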
class HopfieldLayer(nn.Module):
    """
    Attention (HopfieldMHA) plus a feed-forward Hopfield block sharing one
    energy: both `forward` and `energy` are the sums of the two sub-modules'.
    """

    def __init__(
        self,
        embedding_dim,
        nheads,
        forward_memories,
        forward_activation="relu",
        beta=1.0,
        bias=True,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()

        if forward_activation == "relu":
            self.ffn = HopfieldReLU(
                embedding_dim=embedding_dim,
                nmemories=forward_memories,
                bias=bias,
                device=device,
                dropout=dropout,
                initializer_range=initializer_range,
            )
        elif forward_activation == "softmax":
            self.ffn = HopfieldSoftmax(
                embedding_dim=embedding_dim,
                nmemories=forward_memories,
                beta=beta,
                bias=bias,
                device=device,
                dropout=dropout,
                initializer_range=initializer_range,
            )
        else:
            raise ValueError(
                f"Unsupported forward_activation='{forward_activation}'. "
                "Expected one of: 'relu', 'softmax'."
            )

        self.mha = HopfieldMHA(
            embedding_dim=embedding_dim,
            nheads=nheads,
            beta=beta,
            device=device,
            dropout=dropout,
            initializer_range=initializer_range,
        )

    def energy(
        self,
        x,
        attention_mask=None,
        allow_self=True,
    ):
        return self.mha.energy(
            x,
            attention_mask=attention_mask,
            allow_self=allow_self,
        ) + self.ffn.energy(x)

    def forward(
        self,
        x,
        attention_mask=None,
        allow_self=True,
    ):
        ffn_output = self.ffn(x)
        mha_output = self.mha(
            x,
            attention_mask=attention_mask,
            allow_self=allow_self,
        )
        return ffn_output + mha_output


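# One possible way to use the layer (a sketch only; the step size, iteration
# count, and sizes are illustrative choices, not something this module
# prescribes): since `forward` returns the gradient of `energy`, stepping
# against it performs gradient descent on the energy of the token states.
def _example_energy_descent():
    layer = HopfieldLayer(embedding_dim=16, nheads=2, forward_memories=32).eval()
    x = torch.randn(2, 5, 16)
    step = 0.1
    energies = []
    with torch.no_grad():
        for _ in range(10):
            energies.append(layer.energy(x).item())
            x = x - step * layer(x)  # move against the energy gradient
    return energies  # non-increasing, up to floating-point noise

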
if __name__ == "__main__":

    def grad_energy(module, x, **kwargs):
        x = x.clone().detach().requires_grad_(True)
        return torch.func.grad(lambda z: module.energy(z, **kwargs))(x)

    def check_module(name, module, x, atol=1e-5, rtol=1e-4, **kwargs):
        module.eval()
        with torch.no_grad():
            forward_out = module(x, **kwargs)
        grad_out = grad_energy(module, x, **kwargs)

        ok = torch.allclose(forward_out, grad_out, atol=atol, rtol=rtol)
        max_diff = (forward_out - grad_out).abs().max().item()

        print(f"\n=== {name} ===")
        print("forward shape :", tuple(forward_out.shape))
        print("grad shape    :", tuple(grad_out.shape))
        print("allclose      :", ok)
        print("max abs diff  :", max_diff)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(0)

    x = torch.randn(2, 4, 12, device=device)
    keep_mask = torch.tensor(
        [[1, 1, 1, 0],
         [1, 1, 0, 0]],
        dtype=torch.bool,
        device=device,
    )

    print("Testing elementary blocks against grad(E)...")

    hrelu = HopfieldReLU(12, 8, bias=False, device=device, dropout=0.0)
    check_module("HopfieldReLU", hrelu, x)

    hsoftmax = HopfieldSoftmax(12, 8, beta=1.0, bias=False, device=device, dropout=0.0)
    check_module("HopfieldSoftmax", hsoftmax, x)

    mha = HopfieldMHA(12, 3, beta=1.0, device=device, dropout=0.0)
    check_module("HopfieldMHA", mha, x, attention_mask=keep_mask)

    layer = HopfieldLayer(
        embedding_dim=12,
        nheads=3,
        forward_memories=16,
        forward_activation="relu",
        beta=1.0,
        bias=True,
        device=device,
        dropout=0.0,
    )
    check_module("HopfieldLayer", layer, x, attention_mask=keep_mask)

    print("\nTesting original-style external normalization...")
    norm = nn.LayerNorm(12).to(device).eval()
    g = norm(x)

    with torch.no_grad():
        update = layer(g, attention_mask=keep_mask)

    g_req = g.clone().detach().requires_grad_(True)
    grad_g = torch.func.grad(lambda z: layer.energy(z, attention_mask=keep_mask))(g_req)

    print("\n=== HopfieldLayer on normalized input g ===")
    print("allclose     :", torch.allclose(update, grad_g, atol=1e-5, rtol=1e-4))
    print("max abs diff :", (update - grad_g).abs().max().item())