Upload hopfield.py with huggingface_hub
hopfield.py
ADDED
@@ -0,0 +1,484 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel


class HopfieldReLU(nn.Module):
    """
    Hopfield ReLU block.

    Forward (the gradient of the energy with respect to x):
        -ReLU(x W^T + b) @ W

    Energy:
        E(x) = -0.5 * sum ReLU(x W^T + b)^2
    """

    def __init__(
        self,
        embedding_dim,
        nmemories,
        bias=True,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()
        self.initializer_range = float(initializer_range)

        self.W = nn.Parameter(torch.empty(nmemories, embedding_dim, device=device))
        self.dropout = nn.Dropout(dropout)

        if bias:
            self.bias = nn.Parameter(torch.empty(nmemories, device=device))
        else:
            self.bias = None

        with torch.no_grad():
            self.W.normal_(mean=0.0, std=self.initializer_range)
            if self.bias is not None:
                self.bias.zero_()

    def forward(self, x):
        """
        x: (B, S, D) or (S, D)
        returns: same shape as x
        """
        squeeze_back = False
        if x.dim() == 2:
            x = x.unsqueeze(0)
            squeeze_back = True

        H = x @ self.W.t()
        if self.bias is not None:
            H = H + self.bias.view(1, 1, -1)

        out = -(torch.relu(H) @ self.W)
        out = self.dropout(out)

        if squeeze_back:
            out = out.squeeze(0)
        return out

    def energy(self, x):
        """
        Scalar energy summed over batch and tokens.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)
        elif x.dim() == 2:
            x = x.unsqueeze(0)

        H = x @ self.W.t()
        if self.bias is not None:
            H = H + self.bias.view(1, 1, -1)

        return -0.5 * (torch.relu(H) ** 2).sum()


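# Usage sketch (an assumption about the intended dynamics, not part of the
# tests below): forward() returns dE/dx, so memory retrieval is plain gradient
# descent on the energy. Step size and iteration count are illustrative.
#
#     block = HopfieldReLU(embedding_dim=64, nmemories=256)
#     x = torch.randn(4, 10, 64)
#     for _ in range(20):
#         x = x - 0.1 * block(x)   # E(x) is non-increasing for small steps

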
class HopfieldSoftmax(nn.Module):
    """
    Hopfield Softmax block (energy-based CHNSoftmax analogue).

    Energy:
        E(x) = -(1 / beta) * sum logsumexp(beta * (x W^T + b))

    Forward (the gradient of the energy with respect to x):
        -softmax(beta * (x W^T + b)) @ W
    """

    def __init__(
        self,
        embedding_dim,
        nmemories,
        beta=1.0,
        bias=True,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()
        self.beta = float(beta)
        self.initializer_range = float(initializer_range)

        self.W = nn.Parameter(torch.empty(nmemories, embedding_dim, device=device))
        self.dropout = nn.Dropout(dropout)

        if bias:
            self.bias = nn.Parameter(torch.empty(nmemories, device=device))
        else:
            self.bias = None

        with torch.no_grad():
            self.W.normal_(mean=0.0, std=self.initializer_range)
            if self.bias is not None:
                self.bias.zero_()

    def forward(self, x):
        """
        x: (D,) or (S, D) or (B, S, D)
        returns: same shape as x
        """
        squeeze_1d = False
        squeeze_2d = False

        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)
            squeeze_1d = True
        elif x.dim() == 2:
            x = x.unsqueeze(0)
            squeeze_2d = True

        logits = self.beta * (x @ self.W.t())
        if self.bias is not None:
            logits = logits + (self.beta * self.bias).view(1, 1, -1)

        A = torch.softmax(logits, dim=-1)
        out = -(A @ self.W)
        out = self.dropout(out)

        if squeeze_1d:
            return out.squeeze(0).squeeze(0)
        if squeeze_2d:
            return out.squeeze(0)
        return out

    def energy(self, x):
        """
        Scalar energy summed over batch and tokens.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)
        elif x.dim() == 2:
            x = x.unsqueeze(0)

        logits = self.beta * (x @ self.W.t())
        if self.bias is not None:
            logits = logits + (self.beta * self.bias).view(1, 1, -1)

        return -(1.0 / self.beta) * torch.logsumexp(logits, dim=-1).sum()


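# Retrieval sketch (assumes the rows of W are interpreted as stored patterns;
# names and constants are illustrative): with a large beta, -forward(x), i.e.
# softmax(beta * x W^T) @ W, is the modern-Hopfield retrieval update and maps
# a corrupted query back toward its nearest stored row in one step.
#
#     mem = HopfieldSoftmax(embedding_dim=32, nmemories=8, beta=8.0, bias=False)
#     with torch.no_grad():
#         mem.W.copy_(F.normalize(torch.randn(8, 32), dim=-1))
#         query = mem.W[0] + 0.1 * torch.randn(32)  # noisy copy of pattern 0
#         retrieved = -mem(query)                   # ~ mem.W[0]

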
class HopfieldMHA(nn.Module):
    """
    Multi-head attention update corresponding to the gradient of the ET
    (Energy Transformer) attention energy.

    Conventions:
    - input x: (B, S, D) or (S, D)
    - attention_mask: (B, S), True/1 = keep, False/0 = pad
    """

    def __init__(
        self,
        embedding_dim,
        nheads,
        beta=None,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()

        if embedding_dim % nheads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by nheads ({nheads})."
            )

        self.nheads = nheads
        self.head_dim = embedding_dim // nheads

        if beta is None:
            self.beta = 1.0 / (self.head_dim ** 0.5)
        else:
            self.beta = float(beta)

        self.initializer_range = float(initializer_range)
        self.dropout = nn.Dropout(dropout)

        self.Wq = nn.Parameter(
            torch.empty(nheads, embedding_dim, self.head_dim, device=device)
        )
        self.Wk = nn.Parameter(
            torch.empty(nheads, embedding_dim, self.head_dim, device=device)
        )

        with torch.no_grad():
            self.Wq.normal_(mean=0.0, std=self.initializer_range)
            self.Wk.normal_(mean=0.0, std=self.initializer_range)

    def _ensure_batch_dim(self, x):
        squeeze_back = False
        if x.dim() == 2:
            x = x.unsqueeze(0)
            squeeze_back = True
        return x, squeeze_back

    def _resolve_keep_mask(self, attention_mask=None):
        if attention_mask is not None:
            return attention_mask.to(torch.bool)
        return None

    def _no_self_mask(self, seq_len, device):
        return torch.eye(seq_len, dtype=torch.bool, device=device)

    def forward(self, x, attention_mask=None, allow_self=True):
        """
        Returns:
            tensor of shape (B, S, D) or (S, D)
        """
        x, squeeze_back = self._ensure_batch_dim(x)
        B, S, _ = x.shape

        keep = self._resolve_keep_mask(attention_mask)

        # Q, K: (B, nheads, S, head_dim)
        Q = x.unsqueeze(1) @ self.Wq
        K = x.unsqueeze(1) @ self.Wk

        WqT = self.Wq.transpose(-1, -2)
        WkT = self.Wk.transpose(-1, -2)

        # The energy depends on x through both Q and K, so its gradient has
        # two terms: T1 (through Q) uses K mapped back by Wq^T as values,
        # and T2 (through K) uses Q mapped back by Wk^T.
        V1 = K @ WqT
        V2 = Q @ WkT

        # SDPA computes softmax(QK^T / sqrt(dk)).
        # We want softmax(beta * QK^T), so we rescale Q by beta * sqrt(dk)
        # in order to match the same logits as in energy() and in the exact T2 path.
        dk = Q.shape[-1]
        q_scale = self.beta * (dk ** 0.5)
        Qs = Q * q_scale

        sdpa_mask = None
        if keep is not None:
            sdpa_mask = keep.view(B, 1, 1, S)

        if not allow_self:
            no_self = ~self._no_self_mask(S, x.device)
            no_self = no_self.view(1, 1, S, S)
            sdpa_mask = no_self if sdpa_mask is None else (sdpa_mask & no_self)

        with sdpa_kernel(
            [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
        ):
            T1 = -F.scaled_dot_product_attention(
                Qs,
                K,
                V1,
                attn_mask=sdpa_mask,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=False,
            )

        # T2 needs A^T, which SDPA cannot return, so this path materializes
        # the attention matrix explicitly with the same logits and masks.
        logits = self.beta * (Q @ K.transpose(-2, -1))
        neg_inf = -float("inf")

        if keep is not None:
            logits = logits.masked_fill((~keep).view(B, 1, 1, S), neg_inf)

        if not allow_self:
            logits = logits.masked_fill(
                self._no_self_mask(S, x.device).view(1, 1, S, S), neg_inf
            )

        A = torch.softmax(logits.float(), dim=-1).to(logits.dtype)
        # Note: in training mode this dropout draws a mask independent of the
        # one applied inside SDPA for T1, so T1 and T2 see different weights.
        A = self.dropout(A)

        T2 = -(A.transpose(-2, -1) @ V2)
        out = (T1 + T2).sum(dim=1)  # sum over heads -> (B, S, D)

        if squeeze_back:
            out = out.squeeze(0)
        return out

    def energy(self, x, attention_mask=None, allow_self=True):
        """
        Scalar energy summed over batch, heads and tokens:

            E = -(1 / beta) * sum logsumexp(beta * <Q, K> + mask)
        """
        x, _ = self._ensure_batch_dim(x)
        B, S, _ = x.shape

        keep = self._resolve_keep_mask(attention_mask)

        Q = x.unsqueeze(1) @ self.Wq
        K = x.unsqueeze(1) @ self.Wk

        logits = self.beta * (Q @ K.transpose(-2, -1))
        neg_inf = -float("inf")

        if keep is not None:
            logits = logits.masked_fill((~keep).view(B, 1, 1, S), neg_inf)

        if not allow_self:
            logits = logits.masked_fill(
                self._no_self_mask(S, x.device).view(1, 1, S, S), neg_inf
            )

        lse = torch.logsumexp(logits, dim=-1)
        return -(1.0 / self.beta) * lse.sum()


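# Sanity-check sketch for the Q-rescaling trick above (a pure identity; all
# values illustrative): with Qs = Q * beta * sqrt(dk), SDPA's internal
# softmax(Qs K^T / sqrt(dk)) equals the desired softmax(beta * Q K^T).
#
#     B, H, S, dk = 1, 2, 5, 4
#     Q, K, V = (torch.randn(B, H, S, dk) for _ in range(3))
#     beta = 0.7
#     ref = torch.softmax(beta * Q @ K.transpose(-2, -1), dim=-1) @ V
#     out = F.scaled_dot_product_attention(Q * beta * dk ** 0.5, K, V)
#     assert torch.allclose(ref, out, atol=1e-5)

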
class HopfieldLayer(nn.Module):
    """
    Attention block plus memory block sharing one energy,
    E(x) = E_mha(x) + E_ffn(x), with forward() returning its gradient.
    """

    def __init__(
        self,
        embedding_dim,
        nheads,
        forward_memories,
        forward_activation="relu",
        beta=1.0,
        bias=True,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()

        if forward_activation == "relu":
            self.ffn = HopfieldReLU(
                embedding_dim=embedding_dim,
                nmemories=forward_memories,
                bias=bias,
                device=device,
                dropout=dropout,
                initializer_range=initializer_range,
            )
        elif forward_activation == "softmax":
            self.ffn = HopfieldSoftmax(
                embedding_dim=embedding_dim,
                nmemories=forward_memories,
                beta=beta,
                bias=bias,
                device=device,
                dropout=dropout,
                initializer_range=initializer_range,
            )
        else:
            raise ValueError(
                f"Unsupported forward_activation='{forward_activation}'. "
                "Expected one of: 'relu', 'softmax'."
            )

        self.mha = HopfieldMHA(
            embedding_dim=embedding_dim,
            nheads=nheads,
            beta=beta,
            device=device,
            dropout=dropout,
            initializer_range=initializer_range,
        )

    def energy(self, x, attention_mask=None, allow_self=True):
        return self.mha.energy(
            x,
            attention_mask=attention_mask,
            allow_self=allow_self,
        ) + self.ffn.energy(x)

    def forward(self, x, attention_mask=None, allow_self=True):
        ffn_output = self.ffn(x)
        mha_output = self.mha(
            x,
            attention_mask=attention_mask,
            allow_self=allow_self,
        )
        return ffn_output + mha_output


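# Energy-Transformer-style loop (an assumption about intended use, mirroring
# the external-normalization test in __main__ below): the layer is applied to
# a LayerNormed copy of the state, and its output, being dE/dg, drives a
# descent step on the raw token states. 12 steps and 0.1 are illustrative.
#
#     layer = HopfieldLayer(embedding_dim=64, nheads=4, forward_memories=128)
#     norm = nn.LayerNorm(64)
#     x = torch.randn(2, 16, 64)
#     for _ in range(12):
#         x = x - 0.1 * layer(norm(x))

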
if __name__ == "__main__":

    def grad_energy(module, x, **kwargs):
        x = x.clone().detach().requires_grad_(True)
        return torch.func.grad(lambda z: module.energy(z, **kwargs))(x)

    def check_module(name, module, x, atol=1e-5, rtol=1e-4, **kwargs):
        module.eval()
        with torch.no_grad():
            forward_out = module(x, **kwargs)
        grad_out = grad_energy(module, x, **kwargs)

        ok = torch.allclose(forward_out, grad_out, atol=atol, rtol=rtol)
        max_diff = (forward_out - grad_out).abs().max().item()

        print(f"\n=== {name} ===")
        print("forward shape :", tuple(forward_out.shape))
        print("grad shape    :", tuple(grad_out.shape))
        print("allclose      :", ok)
        print("max abs diff  :", max_diff)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(0)

    x = torch.randn(2, 4, 12, device=device)
    keep_mask = torch.tensor(
        [[1, 1, 1, 0],
         [1, 1, 0, 0]],
        dtype=torch.bool,
        device=device,
    )

    print("Testing elementary blocks against grad(E)...")

    hrelu = HopfieldReLU(12, 8, bias=False, device=device, dropout=0.0)
    check_module("HopfieldReLU", hrelu, x)

    hsoftmax = HopfieldSoftmax(12, 8, beta=1.0, bias=False, device=device, dropout=0.0)
    check_module("HopfieldSoftmax", hsoftmax, x)

    mha = HopfieldMHA(12, 3, beta=1.0, device=device, dropout=0.0)
    check_module("HopfieldMHA", mha, x, attention_mask=keep_mask)

    layer = HopfieldLayer(
        embedding_dim=12,
        nheads=3,
        forward_memories=16,
        forward_activation="relu",
        beta=1.0,
        bias=True,
        device=device,
        dropout=0.0,
    )
    check_module("HopfieldLayer", layer, x, attention_mask=keep_mask)

    print("\nTesting original-style external normalization...")
    norm = nn.LayerNorm(12).to(device).eval()
    g = norm(x)

    with torch.no_grad():
        update = layer(g, attention_mask=keep_mask)

    g_req = g.clone().detach().requires_grad_(True)
    grad_g = torch.func.grad(lambda z: layer.energy(z, attention_mask=keep_mask))(g_req)

    print("\n=== HopfieldLayer on normalized input g ===")
    print("allclose     :", torch.allclose(update, grad_g, atol=1e-5, rtol=1e-4))
    print("max abs diff :", (update - grad_g).abs().max().item())