import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel


class HopfieldReLU(nn.Module):
    """
    Hopfield ReLU block.

    Forward:
        -ReLU(x W^T + b) @ W

    Energy:
        E(x) = -0.5 * sum ReLU(x W^T + b)^2
    """

    def __init__(
        self,
        embedding_dim,
        nmemories,
        bias=True,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()
        self.initializer_range = float(initializer_range)
        self.W = nn.Parameter(torch.empty(nmemories, embedding_dim, device=device))
        self.dropout = nn.Dropout(dropout)
        if bias:
            self.bias = nn.Parameter(torch.empty(nmemories, device=device))
        else:
            self.bias = None
        with torch.no_grad():
            self.W.normal_(mean=0.0, std=self.initializer_range)
            if self.bias is not None:
                self.bias.zero_()

    def forward(self, x):
        """
        x: (B, S, D) or (S, D)
        returns: same shape as x
        """
        squeeze_back = False
        if x.dim() == 2:
            x = x.unsqueeze(0)
            squeeze_back = True
        H = x @ self.W.t()
        if self.bias is not None:
            H = H + self.bias.view(1, 1, -1)
        out = -(torch.relu(H) @ self.W)
        out = self.dropout(out)
        if squeeze_back:
            out = out.squeeze(0)
        return out

    def energy(self, x):
        """
        Scalar energy summed over batch and tokens.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)
        elif x.dim() == 2:
            x = x.unsqueeze(0)
        H = x @ self.W.t()
        if self.bias is not None:
            H = H + self.bias.view(1, 1, -1)
        return -0.5 * (torch.relu(H) ** 2).sum()


class HopfieldSoftmax(nn.Module):
    """
    Hopfield Softmax block (energy-based CHNSoftmax analogue).

    Energy:
        E(x) = -(1 / beta) * sum logsumexp(beta * (x W^T + b))

    Forward:
        -softmax(beta * (x W^T + b)) @ W
    """

    def __init__(
        self,
        embedding_dim,
        nmemories,
        beta=1.0,
        bias=True,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()
        self.beta = float(beta)
        self.initializer_range = float(initializer_range)
        self.W = nn.Parameter(torch.empty(nmemories, embedding_dim, device=device))
        self.dropout = nn.Dropout(dropout)
        if bias:
            self.bias = nn.Parameter(torch.empty(nmemories, device=device))
        else:
            self.bias = None
        with torch.no_grad():
            self.W.normal_(mean=0.0, std=self.initializer_range)
            if self.bias is not None:
                self.bias.zero_()

    def forward(self, x):
        """
        x: (D,) or (S, D) or (B, S, D)
        returns: same shape as x
        """
        squeeze_1d = False
        squeeze_2d = False
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)
            squeeze_1d = True
        elif x.dim() == 2:
            x = x.unsqueeze(0)
            squeeze_2d = True
        logits = self.beta * (x @ self.W.t())
        if self.bias is not None:
            logits = logits + (self.beta * self.bias).view(1, 1, -1)
        A = torch.softmax(logits, dim=-1)
        out = -(A @ self.W)
        out = self.dropout(out)
        if squeeze_1d:
            return out.squeeze(0).squeeze(0)
        if squeeze_2d:
            return out.squeeze(0)
        return out

    def energy(self, x):
        """
        Scalar energy summed over batch and tokens.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)
        elif x.dim() == 2:
            x = x.unsqueeze(0)
        logits = self.beta * (x @ self.W.t())
        if self.bias is not None:
            logits = logits + (self.beta * self.bias).view(1, 1, -1)
        return -(1.0 / self.beta) * torch.logsumexp(logits, dim=-1).sum()
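
# Illustrative sketch (not part of the original module): because forward()
# returns grad_x E, the classic modern-Hopfield retrieval update
# x_new = softmax(beta * x W^T) @ W is just the negated forward pass. With the
# rows of W set to stored patterns and a large beta, a few such steps should
# pull a noisy query onto its nearest stored pattern. The helper name and
# hyperparameters below are hypothetical.
def _softmax_retrieval_demo(patterns, query, beta=8.0, steps=3):
    """patterns: (M, D) stored patterns, query: (D,) noisy probe."""
    block = HopfieldSoftmax(
        embedding_dim=patterns.shape[1],
        nmemories=patterns.shape[0],
        beta=beta,
        bias=False,
        device=patterns.device,
    )
    with torch.no_grad():
        block.W.copy_(patterns)
        x = query.clone()
        for _ in range(steps):
            # forward() is grad E, so the retrieval step is its negation.
            x = -block(x)
    return x
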
class HopfieldMHA(nn.Module):
    """
    Multi-head attention update corresponding to the gradient of the ET
    attention energy.

    Conventions:
        - input x: (B, S, D) or (S, D)
        - attention_mask: (B, S), True/1 = keep, False/0 = pad
    """

    def __init__(
        self,
        embedding_dim,
        nheads,
        beta=None,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()
        if embedding_dim % nheads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by nheads ({nheads})."
            )
        self.nheads = nheads
        self.head_dim = embedding_dim // nheads
        if beta is None:
            self.beta = 1.0 / (self.head_dim ** 0.5)
        else:
            self.beta = float(beta)
        self.initializer_range = float(initializer_range)
        self.dropout = nn.Dropout(dropout)
        self.Wq = nn.Parameter(
            torch.empty(nheads, embedding_dim, self.head_dim, device=device)
        )
        self.Wk = nn.Parameter(
            torch.empty(nheads, embedding_dim, self.head_dim, device=device)
        )
        with torch.no_grad():
            self.Wq.normal_(mean=0.0, std=self.initializer_range)
            self.Wk.normal_(mean=0.0, std=self.initializer_range)

    def _ensure_batch_dim(self, x):
        squeeze_back = False
        if x.dim() == 2:
            x = x.unsqueeze(0)
            squeeze_back = True
        return x, squeeze_back

    def _resolve_keep_mask(self, attention_mask=None):
        if attention_mask is not None:
            return attention_mask.to(torch.bool)
        return None

    def _no_self_mask(self, seq_len, device):
        return torch.eye(seq_len, dtype=torch.bool, device=device)

    def forward(
        self,
        x,
        attention_mask=None,
        allow_self=True,
    ):
        """
        Returns: tensor of shape (B, S, D) or (S, D)
        """
        x, squeeze_back = self._ensure_batch_dim(x)
        B, S, _ = x.shape
        keep = self._resolve_keep_mask(attention_mask)

        # Per-head query/key projections: (B, H, S, head_dim).
        Q = x.unsqueeze(1) @ self.Wq
        K = x.unsqueeze(1) @ self.Wk

        WqT = self.Wq.transpose(-1, -2)
        WkT = self.Wk.transpose(-1, -2)
        # V1/V2 map keys/queries back to embedding space; they carry the two
        # halves of grad_x E (through Q and through K, respectively).
        V1 = K @ WqT
        V2 = Q @ WkT

        # SDPA computes softmax(QK^T / sqrt(dk)).
        # We want softmax(beta * QK^T), so we rescale Q by beta * sqrt(dk)
        # in order to match the same logits as in energy() and in the exact T2 path.
        dk = Q.shape[-1]
        q_scale = self.beta * (dk ** 0.5)
        Qs = Q * q_scale

        sdpa_mask = None
        if keep is not None:
            sdpa_mask = keep.view(B, 1, 1, S)
        if not allow_self:
            no_self = ~self._no_self_mask(S, x.device)
            no_self = no_self.view(1, 1, S, S)
            sdpa_mask = no_self if sdpa_mask is None else (sdpa_mask & no_self)

        with sdpa_kernel(
            [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
        ):
            T1 = -F.scaled_dot_product_attention(
                Qs,
                K,
                V1,
                attn_mask=sdpa_mask,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=False,
            )

        logits = self.beta * (Q @ K.transpose(-2, -1))
        neg_inf = -float("inf")
        if keep is not None:
            logits = logits.masked_fill((~keep).view(B, 1, 1, S), neg_inf)
        if not allow_self:
            logits = logits.masked_fill(
                self._no_self_mask(S, x.device).view(1, 1, S, S), neg_inf
            )
        A = torch.softmax(logits.float(), dim=-1).to(logits.dtype)
        A = self.dropout(A)
        T2 = -(A.transpose(-2, -1) @ V2)

        out = (T1 + T2).sum(dim=1)
        if squeeze_back:
            out = out.squeeze(0)
        return out

    def energy(
        self,
        x,
        attention_mask=None,
        allow_self=True,
    ):
        """
        Scalar energy summed over batch, heads and tokens.

        E(x) = -(1 / beta) * sum logsumexp(beta * Q K^T + mask)
        """
        x, _ = self._ensure_batch_dim(x)
        B, S, _ = x.shape
        keep = self._resolve_keep_mask(attention_mask)

        Q = x.unsqueeze(1) @ self.Wq
        K = x.unsqueeze(1) @ self.Wk

        logits = self.beta * (Q @ K.transpose(-2, -1))
        neg_inf = -float("inf")
        if keep is not None:
            logits = logits.masked_fill((~keep).view(B, 1, 1, S), neg_inf)
        if not allow_self:
            logits = logits.masked_fill(
                self._no_self_mask(S, x.device).view(1, 1, S, S), neg_inf
            )
        lse = torch.logsumexp(logits, dim=-1)
        return -(1.0 / self.beta) * lse.sum()
class HopfieldLayer(nn.Module):
    def __init__(
        self,
        embedding_dim,
        nheads,
        forward_memories,
        forward_activation="relu",
        beta=1.0,
        bias=True,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()
        if forward_activation == "relu":
            self.ffn = HopfieldReLU(
                embedding_dim=embedding_dim,
                nmemories=forward_memories,
                bias=bias,
                device=device,
                dropout=dropout,
                initializer_range=initializer_range,
            )
        elif forward_activation == "softmax":
            self.ffn = HopfieldSoftmax(
                embedding_dim=embedding_dim,
                nmemories=forward_memories,
                beta=beta,
                bias=bias,
                device=device,
                dropout=dropout,
                initializer_range=initializer_range,
            )
        else:
            raise ValueError(
                f"Not implemented forward_activation='{forward_activation}'. "
                "Expected one of: 'relu', 'softmax'."
            )
        self.mha = HopfieldMHA(
            embedding_dim=embedding_dim,
            nheads=nheads,
            beta=beta,
            device=device,
            dropout=dropout,
            initializer_range=initializer_range,
        )

    def energy(
        self,
        x,
        attention_mask=None,
        allow_self=True,
    ):
        return self.mha.energy(
            x,
            attention_mask=attention_mask,
            allow_self=allow_self,
        ) + self.ffn.energy(x)

    def forward(
        self,
        x,
        attention_mask=None,
        allow_self=True,
    ):
        ffn_output = self.ffn(x)
        mha_output = self.mha(
            x,
            attention_mask=attention_mask,
            allow_self=allow_self,
        )
        return ffn_output + mha_output


if __name__ == "__main__":

    def grad_energy(module, x, **kwargs):
        x = x.clone().detach().requires_grad_(True)
        return torch.func.grad(lambda z: module.energy(z, **kwargs))(x)

    def check_module(name, module, x, atol=1e-5, rtol=1e-4, **kwargs):
        module.eval()
        with torch.no_grad():
            forward_out = module(x, **kwargs)
        grad_out = grad_energy(module, x, **kwargs)
        ok = torch.allclose(forward_out, grad_out, atol=atol, rtol=rtol)
        max_diff = (forward_out - grad_out).abs().max().item()
        print(f"\n=== {name} ===")
        print("forward shape :", tuple(forward_out.shape))
        print("grad shape    :", tuple(grad_out.shape))
        print("allclose      :", ok)
        print("max abs diff  :", max_diff)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(0)

    x = torch.randn(2, 4, 12, device=device)
    keep_mask = torch.tensor(
        [[1, 1, 1, 0], [1, 1, 0, 0]],
        dtype=torch.bool,
        device=device,
    )

    print("Testing elementary blocks against grad(E)...")

    hrelu = HopfieldReLU(12, 8, bias=False, device=device, dropout=0.0)
    check_module("HopfieldReLU", hrelu, x)

    hsoftmax = HopfieldSoftmax(12, 8, beta=1.0, bias=False, device=device, dropout=0.0)
    check_module("HopfieldSoftmax", hsoftmax, x)

    mha = HopfieldMHA(12, 3, beta=1.0, device=device, dropout=0.0)
    check_module("HopfieldMHA", mha, x, attention_mask=keep_mask)

    layer = HopfieldLayer(
        embedding_dim=12,
        nheads=3,
        forward_memories=16,
        forward_activation="relu",
        beta=1.0,
        bias=True,
        device=device,
        dropout=0.0,
    )
    check_module("HopfieldLayer", layer, x, attention_mask=keep_mask)

    print("\nTesting original-style external normalization...")
    norm = nn.LayerNorm(12).to(device).eval()
    g = norm(x)
    with torch.no_grad():
        update = layer(g, attention_mask=keep_mask)
    g_req = g.clone().detach().requires_grad_(True)
    grad_g = torch.func.grad(lambda z: layer.energy(z, attention_mask=keep_mask))(g_req)

    print("\n=== HopfieldLayer on normalized input g ===")
    print("allclose      :", torch.allclose(update, grad_g, atol=1e-5, rtol=1e-4))
    print("max abs diff  :", (update - grad_g).abs().max().item())
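
    # Extra sanity check (a sketch, not part of the original tests): since the
    # forward pass returns grad_x E, stepping against it should typically lower
    # the layer energy for a small enough step. Step size and iteration count
    # below are arbitrary choices for illustration.
    print("\nTesting energy descent with x <- x - step * layer(x)...")
    z = x.clone()
    step = 0.05
    with torch.no_grad():
        prev = layer.energy(z, attention_mask=keep_mask).item()
        for it in range(5):
            z = z - step * layer(z, attention_mask=keep_mask)
            e = layer.energy(z, attention_mask=keep_mask).item()
            print(f"iter {it}: energy {prev:.6f} -> {e:.6f}")
            prev = e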