# HOP4NLP / hopfield.py
# Uploaded by TCMVince via huggingface_hub (commit d972a70, verified).
from math import sqrt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
class HopfieldReLU(nn.Module):
    """
    Hopfield ReLU block.

    Forward (the gradient of the energy w.r.t. x):
        -ReLU(x W^T + b) @ W
    Energy:
        E(x) = -0.5 * sum ReLU(x W^T + b)^2

    Accepts inputs of shape (D,), (S, D) or (B, S, D); the output has the
    same shape as the input.

    Args:
        embedding_dim: token dimension D.
        nmemories: number of stored patterns (rows of W).
        bias: whether to learn a per-memory bias b.
        device: optional device for the parameters.
        dropout: dropout probability applied to the forward output.
        initializer_range: std of the normal init for W.
    """
    def __init__(
        self,
        embedding_dim,
        nmemories,
        bias=True,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()
        self.initializer_range = float(initializer_range)
        # W: (nmemories, embedding_dim) -- one stored pattern per row.
        self.W = nn.Parameter(torch.empty(nmemories, embedding_dim, device=device))
        self.dropout = nn.Dropout(dropout)
        if bias:
            self.bias = nn.Parameter(torch.empty(nmemories, device=device))
        else:
            self.bias = None
        with torch.no_grad():
            self.W.normal_(mean=0.0, std=self.initializer_range)
            if self.bias is not None:
                self.bias.zero_()
    def forward(self, x):
        """
        x: (D,), (S, D) or (B, S, D)
        returns: same shape as x
        """
        # Canonicalize to (B, S, D) and remember how to restore the shape.
        # 1-D input was previously unsupported here even though energy()
        # accepted it; handling it makes forward consistent with
        # HopfieldSoftmax.forward. Existing 2-D/3-D callers are unaffected.
        squeeze_1d = False
        squeeze_2d = False
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)
            squeeze_1d = True
        elif x.dim() == 2:
            x = x.unsqueeze(0)
            squeeze_2d = True
        H = x @ self.W.t()
        if self.bias is not None:
            H = H + self.bias.view(1, 1, -1)
        # Gradient of E(x) = -0.5 * ||ReLU(H)||^2 w.r.t. x.
        out = -(torch.relu(H) @ self.W)
        out = self.dropout(out)
        if squeeze_1d:
            return out.squeeze(0).squeeze(0)
        if squeeze_2d:
            return out.squeeze(0)
        return out
    def energy(self, x):
        """
        Scalar energy summed over batch and tokens.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)
        elif x.dim() == 2:
            x = x.unsqueeze(0)
        H = x @ self.W.t()
        if self.bias is not None:
            H = H + self.bias.view(1, 1, -1)
        return -0.5 * (torch.relu(H) ** 2).sum()
class HopfieldSoftmax(nn.Module):
    """
    Energy-based softmax Hopfield block (CHNSoftmax analogue).

    Energy:
        E(x) = -(1 / beta) * sum logsumexp(beta * (x W^T + b))
    Forward (the gradient of E w.r.t. x):
        -softmax(beta * (x W^T + b)) @ W
    """
    def __init__(
        self,
        embedding_dim,
        nmemories,
        beta=1.0,
        bias=True,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()
        self.beta = float(beta)
        self.initializer_range = float(initializer_range)
        # One stored pattern per row: (nmemories, embedding_dim).
        self.W = nn.Parameter(torch.empty(nmemories, embedding_dim, device=device))
        self.dropout = nn.Dropout(dropout)
        self.bias = (
            nn.Parameter(torch.empty(nmemories, device=device)) if bias else None
        )
        with torch.no_grad():
            self.W.normal_(mean=0.0, std=self.initializer_range)
            if self.bias is not None:
                self.bias.zero_()
    def _logits(self, x):
        # beta * (x W^T + b), computed as beta*(x W^T) + beta*b so that
        # forward() and energy() share the exact same logits.
        scores = self.beta * (x @ self.W.t())
        if self.bias is not None:
            scores = scores + (self.beta * self.bias).view(1, 1, -1)
        return scores
    def forward(self, x):
        """
        x: (D,) or (S, D) or (B, S, D)
        returns: same shape as x
        """
        # Canonicalize to (B, S, D); restore the original rank on the way out.
        ndim_in = x.dim()
        while x.dim() < 3:
            x = x.unsqueeze(0)
        weights = torch.softmax(self._logits(x), dim=-1)
        update = self.dropout(-(weights @ self.W))
        if ndim_in == 1:
            return update.squeeze(0).squeeze(0)
        if ndim_in == 2:
            return update.squeeze(0)
        return update
    def energy(self, x):
        """
        Scalar energy summed over batch and tokens.
        """
        while x.dim() < 3:
            x = x.unsqueeze(0)
        lse = torch.logsumexp(self._logits(x), dim=-1)
        return -(1.0 / self.beta) * lse.sum()
class HopfieldMHA(nn.Module):
    """
    Multi-head attention update corresponding to the gradient of the ET attention energy.
    Conventions:
    - input x: (B, S, D) or (S, D)
    - attention_mask: (B, S), True/1 = keep, False/0 = pad
    """
    def __init__(
        self,
        embedding_dim,
        nheads,
        beta=None,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()
        if embedding_dim % nheads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by nheads ({nheads})."
            )
        self.nheads = nheads
        self.head_dim = embedding_dim // nheads
        # Default inverse temperature: the usual 1/sqrt(head_dim) attention scaling.
        if beta is None:
            self.beta = 1.0 / (self.head_dim ** 0.5)
        else:
            self.beta = float(beta)
        self.initializer_range = float(initializer_range)
        self.dropout = nn.Dropout(dropout)
        # Per-head query/key projections, shape (nheads, embedding_dim, head_dim).
        # There are no value/output projections: the update maps back through
        # Wq^T / Wk^T below, as required for forward() to be grad(energy).
        self.Wq = nn.Parameter(
            torch.empty(nheads, embedding_dim, self.head_dim, device=device)
        )
        self.Wk = nn.Parameter(
            torch.empty(nheads, embedding_dim, self.head_dim, device=device)
        )
        with torch.no_grad():
            self.Wq.normal_(mean=0.0, std=self.initializer_range)
            self.Wk.normal_(mean=0.0, std=self.initializer_range)
    def _ensure_batch_dim(self, x):
        # Promote (S, D) to (1, S, D); report whether a batch dim was added
        # so the caller can squeeze it back out.
        squeeze_back = False
        if x.dim() == 2:
            x = x.unsqueeze(0)
            squeeze_back = True
        return x, squeeze_back
    def _resolve_keep_mask(self, attention_mask=None):
        # Normalize the mask to bool (True = keep); None means keep everything.
        if attention_mask is not None:
            return attention_mask.to(torch.bool)
        return None
    def _no_self_mask(self, seq_len, device):
        # (S, S) boolean identity: True on the diagonal, i.e. the positions
        # to exclude when allow_self=False.
        return torch.eye(seq_len, dtype=torch.bool, device=device)
    def forward(
        self,
        x,
        attention_mask=None,
        allow_self=True,
    ):
        """
        Attention update (gradient of the attention energy) for x.

        x: (B, S, D) or (S, D)
        attention_mask: optional (B, S) keep-mask (True/1 = keep)
        allow_self: when False, each token may not attend to itself.

        Returns:
            tensor of shape (B, S, D) or (S, D)
        """
        x, squeeze_back = self._ensure_batch_dim(x)
        B, S, _ = x.shape
        keep = self._resolve_keep_mask(attention_mask)
        # Per-head projections: (B, H, S, head_dim).
        Q = x.unsqueeze(1) @ self.Wq
        K = x.unsqueeze(1) @ self.Wk
        WqT = self.Wq.transpose(-1, -2)
        WkT = self.Wk.transpose(-1, -2)
        # "Values" for the two gradient terms: V1 maps keys back through Wq^T
        # (query-side term T1), V2 maps queries back through Wk^T (key-side
        # term T2). Both have shape (B, H, S, D).
        V1 = K @ WqT
        V2 = Q @ WkT
        # SDPA computes softmax(QK^T / sqrt(dk)).
        # We want softmax(beta * QK^T), so we rescale Q by beta * sqrt(dk)
        # in order to match the same logits as in energy() and in the exact T2 path.
        dk = Q.shape[-1]
        q_scale = self.beta * (dk ** 0.5)
        Qs = Q * q_scale
        sdpa_mask = None
        if keep is not None:
            # Key-side padding mask broadcast over heads and query positions.
            sdpa_mask = keep.view(B, 1, 1, S)
        if not allow_self:
            # Combine with a diagonal-off mask (True = attend is allowed).
            no_self = ~self._no_self_mask(S, x.device)
            no_self = no_self.view(1, 1, S, S)
            sdpa_mask = no_self if sdpa_mask is None else (sdpa_mask & no_self)
        with sdpa_kernel(
            [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
        ):
            # T1 = -(A @ V1): query-side term, computed via fused SDPA.
            T1 = -F.scaled_dot_product_attention(
                Qs,
                K,
                V1,
                attn_mask=sdpa_mask,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=False,
            )
        # T2 needs the attention matrix A explicitly, so rebuild the exact
        # logits (softmax in float32 for numerical stability).
        logits = self.beta * (Q @ K.transpose(-2, -1))
        neg_inf = -float("inf")
        if keep is not None:
            logits = logits.masked_fill((~keep).view(B, 1, 1, S), neg_inf)
        if not allow_self:
            logits = logits.masked_fill(
                self._no_self_mask(S, x.device).view(1, 1, S, S), neg_inf
            )
        A = torch.softmax(logits.float(), dim=-1).to(logits.dtype)
        # NOTE(review): in training mode the SDPA dropout inside T1 and this
        # dropout on A draw independent random masks, so T1 and T2 are not
        # dropped consistently — confirm that is intended.
        A = self.dropout(A)
        # T2 = -(A^T @ V2): key-side term of the energy gradient.
        T2 = -(A.transpose(-2, -1) @ V2)
        # Sum per-head contributions: (B, H, S, D) -> (B, S, D).
        out = (T1 + T2).sum(dim=1)
        if squeeze_back:
            out = out.squeeze(0)
        return out
    def energy(
        self,
        x,
        attention_mask=None,
        allow_self=True,
    ):
        """
        Scalar energy summed over batch, heads and tokens.
        E = -(1 / beta) * sum logsumexp(beta * <Q, K> + mask)
        """
        x, _ = self._ensure_batch_dim(x)
        B, S, _ = x.shape
        keep = self._resolve_keep_mask(attention_mask)
        Q = x.unsqueeze(1) @ self.Wq
        K = x.unsqueeze(1) @ self.Wk
        logits = self.beta * (Q @ K.transpose(-2, -1))
        neg_inf = -float("inf")
        if keep is not None:
            logits = logits.masked_fill((~keep).view(B, 1, 1, S), neg_inf)
        if not allow_self:
            logits = logits.masked_fill(
                self._no_self_mask(S, x.device).view(1, 1, S, S), neg_inf
            )
        # NOTE(review): a fully-masked row (every key excluded) makes this
        # logsumexp -inf and the energy infinite — confirm callers never pass
        # such masks (e.g. allow_self=False with S == 1).
        lse = torch.logsumexp(logits, dim=-1)
        return -(1.0 / self.beta) * lse.sum()
class HopfieldLayer(nn.Module):
    """
    Attention + feed-forward Hopfield block.

    Both the forward pass and the energy are the sums of the corresponding
    quantities of the two sub-modules, so forward() equals the gradient of
    energy() with respect to x.
    """
    def __init__(
        self,
        embedding_dim,
        nheads,
        forward_memories,
        forward_activation="relu",
        beta=1.0,
        bias=True,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()
        # Arguments shared by both feed-forward variants.
        ffn_kwargs = dict(
            embedding_dim=embedding_dim,
            nmemories=forward_memories,
            bias=bias,
            device=device,
            dropout=dropout,
            initializer_range=initializer_range,
        )
        if forward_activation == "relu":
            self.ffn = HopfieldReLU(**ffn_kwargs)
        elif forward_activation == "softmax":
            self.ffn = HopfieldSoftmax(beta=beta, **ffn_kwargs)
        else:
            raise ValueError(
                f"Not implemented forward_activation='{forward_activation}'. "
                "Expected one of: 'relu', 'softmax'."
            )
        self.mha = HopfieldMHA(
            embedding_dim=embedding_dim,
            nheads=nheads,
            beta=beta,
            device=device,
            dropout=dropout,
            initializer_range=initializer_range,
        )
    def energy(self, x, attention_mask=None, allow_self=True):
        """Scalar energy: attention energy plus feed-forward energy."""
        attention_term = self.mha.energy(
            x,
            attention_mask=attention_mask,
            allow_self=allow_self,
        )
        return attention_term + self.ffn.energy(x)
    def forward(self, x, attention_mask=None, allow_self=True):
        """Sum of the feed-forward and attention updates; same shape as x."""
        return self.ffn(x) + self.mha(
            x,
            attention_mask=attention_mask,
            allow_self=allow_self,
        )
if __name__ == "__main__":
    def energy_gradient(module, x, **kwargs):
        # Gradient of the scalar energy w.r.t. a detached copy of x.
        leaf = x.clone().detach().requires_grad_(True)
        return torch.func.grad(lambda z: module.energy(z, **kwargs))(leaf)

    def check_module(name, module, x, atol=1e-5, rtol=1e-4, **kwargs):
        # Each Hopfield block's forward pass should coincide with grad(E).
        module.eval()
        with torch.no_grad():
            forward_out = module(x, **kwargs)
        grad_out = energy_gradient(module, x, **kwargs)
        matches = torch.allclose(forward_out, grad_out, atol=atol, rtol=rtol)
        worst = (forward_out - grad_out).abs().max().item()
        print(f"\n=== {name} ===")
        print("forward shape :", tuple(forward_out.shape))
        print("grad shape :", tuple(grad_out.shape))
        print("allclose :", matches)
        print("max abs diff :", worst)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(0)
    x = torch.randn(2, 4, 12, device=device)
    # Per-batch keep masks: row 0 pads the last token, row 1 the last two.
    keep_mask = torch.tensor(
        [[1, 1, 1, 0],
         [1, 1, 0, 0]],
        dtype=torch.bool,
        device=device,
    )
    print("Testing elementary blocks against grad(E)...")
    check_module(
        "HopfieldReLU",
        HopfieldReLU(12, 8, bias=False, device=device, dropout=0.0),
        x,
    )
    check_module(
        "HopfieldSoftmax",
        HopfieldSoftmax(12, 8, beta=1.0, bias=False, device=device, dropout=0.0),
        x,
    )
    check_module(
        "HopfieldMHA",
        HopfieldMHA(12, 3, beta=1.0, device=device, dropout=0.0),
        x,
        attention_mask=keep_mask,
    )
    layer = HopfieldLayer(
        embedding_dim=12,
        nheads=3,
        forward_memories=16,
        forward_activation="relu",
        beta=1.0,
        bias=True,
        device=device,
        dropout=0.0,
    )
    check_module("HopfieldLayer", layer, x, attention_mask=keep_mask)
    print("\nTesting original-style external normalization...")
    norm = nn.LayerNorm(12).to(device).eval()
    g = norm(x)
    with torch.no_grad():
        layer_update = layer(g, attention_mask=keep_mask)
    g_leaf = g.clone().detach().requires_grad_(True)
    energy_grad = torch.func.grad(
        lambda z: layer.energy(z, attention_mask=keep_mask)
    )(g_leaf)
    print("\n=== HopfieldLayer on normalized input g ===")
    print("allclose :", torch.allclose(layer_update, energy_grad, atol=1e-5, rtol=1e-4))
    print("max abs diff :", (layer_update - energy_grad).abs().max().item())