TCMVince committed
Commit 4c3a425 · verified · Parent(s): 587e3b7

Upload hopfield.py with huggingface_hub

Files changed (1): hopfield.py (+484, -0)

hopfield.py ADDED
@@ -0,0 +1,484 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

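# Compatibility note (added, hedged): SDPBackend and sdpa_kernel live in
# torch.nn.attention only in recent PyTorch (>= 2.3); older releases exposed
# this functionality as torch.backends.cuda.sdp_kernel.
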
class HopfieldReLU(nn.Module):
    """
    Hopfield ReLU block.

    Forward (the gradient of the energy with respect to x):
        -ReLU(x W^T + b) @ W

    Energy:
        E(x) = -0.5 * sum ReLU(x W^T + b)^2
    """

    def __init__(
        self,
        embedding_dim,
        nmemories,
        bias=True,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()
        self.initializer_range = float(initializer_range)

        self.W = nn.Parameter(torch.empty(nmemories, embedding_dim, device=device))
        self.dropout = nn.Dropout(dropout)

        if bias:
            self.bias = nn.Parameter(torch.empty(nmemories, device=device))
        else:
            self.bias = None

        with torch.no_grad():
            self.W.normal_(mean=0.0, std=self.initializer_range)
            if self.bias is not None:
                self.bias.zero_()

    def forward(self, x):
        """
        x: (B, S, D) or (S, D)
        returns: same shape as x
        """
        squeeze_back = False
        if x.dim() == 2:
            x = x.unsqueeze(0)
            squeeze_back = True

        H = x @ self.W.t()
        if self.bias is not None:
            H = H + self.bias.view(1, 1, -1)

        # dE/dx = -ReLU(H) @ W, so the forward pass is the energy gradient.
        out = -(torch.relu(H) @ self.W)
        out = self.dropout(out)

        if squeeze_back:
            out = out.squeeze(0)
        return out

    def energy(self, x):
        """
        Scalar energy summed over batch and tokens.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)
        elif x.dim() == 2:
            x = x.unsqueeze(0)

        H = x @ self.W.t()
        if self.bias is not None:
            H = H + self.bias.view(1, 1, -1)

        return -0.5 * (torch.relu(H) ** 2).sum()

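# A minimal usage sketch (added illustration; shapes and the step size are
# assumed, not from the original file). Because forward() returns dE/dx, one
# gradient-descent step on the energy is x - lr * block(x):
#
#     block = HopfieldReLU(embedding_dim=16, nmemories=32)
#     x = torch.randn(2, 5, 16)          # (B, S, D)
#     x_next = x - 0.1 * block(x)        # one descent step on block.energy(x)
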
class HopfieldSoftmax(nn.Module):
    """
    Hopfield Softmax block (energy-based CHNSoftmax analogue).

    Energy:
        E(x) = -(1 / beta) * sum logsumexp(beta * (x W^T + b))

    Forward (the gradient of the energy with respect to x):
        -softmax(beta * (x W^T + b)) @ W
    """

    def __init__(
        self,
        embedding_dim,
        nmemories,
        beta=1.0,
        bias=True,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()
        self.beta = float(beta)
        self.initializer_range = float(initializer_range)

        self.W = nn.Parameter(torch.empty(nmemories, embedding_dim, device=device))
        self.dropout = nn.Dropout(dropout)

        if bias:
            self.bias = nn.Parameter(torch.empty(nmemories, device=device))
        else:
            self.bias = None

        with torch.no_grad():
            self.W.normal_(mean=0.0, std=self.initializer_range)
            if self.bias is not None:
                self.bias.zero_()

    def forward(self, x):
        """
        x: (D,) or (S, D) or (B, S, D)
        returns: same shape as x
        """
        squeeze_1d = False
        squeeze_2d = False

        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)
            squeeze_1d = True
        elif x.dim() == 2:
            x = x.unsqueeze(0)
            squeeze_2d = True

        logits = self.beta * (x @ self.W.t())
        if self.bias is not None:
            logits = logits + (self.beta * self.bias).view(1, 1, -1)

        # dE/dx = -softmax(logits) @ W, matching energy() below.
        A = torch.softmax(logits, dim=-1)
        out = -(A @ self.W)
        out = self.dropout(out)

        if squeeze_1d:
            return out.squeeze(0).squeeze(0)
        if squeeze_2d:
            return out.squeeze(0)
        return out

    def energy(self, x):
        """
        Scalar energy summed over batch and tokens.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)
        elif x.dim() == 2:
            x = x.unsqueeze(0)

        logits = self.beta * (x @ self.W.t())
        if self.bias is not None:
            logits = logits + (self.beta * self.bias).view(1, 1, -1)

        return -(1.0 / self.beta) * torch.logsumexp(logits, dim=-1).sum()

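# A minimal usage sketch (added illustration; values are assumed). The rows of
# W act as stored memories and beta sets the sharpness of the softmax over
# them:
#
#     mem = HopfieldSoftmax(embedding_dim=16, nmemories=32, beta=4.0)
#     q = torch.randn(16)                # a single state, shape (D,)
#     step = mem(q)                      # equals grad of mem.energy at q
#     q_next = q - 0.1 * step            # small descent step on the energy
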
class HopfieldMHA(nn.Module):
    """
    Multi-head attention update corresponding to the gradient of the ET
    attention energy.

    Conventions:
        - input x: (B, S, D) or (S, D)
        - attention_mask: (B, S), True/1 = keep, False/0 = pad
    """

    def __init__(
        self,
        embedding_dim,
        nheads,
        beta=None,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()

        if embedding_dim % nheads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by nheads ({nheads})."
            )

        self.nheads = nheads
        self.head_dim = embedding_dim // nheads

        if beta is None:
            self.beta = 1.0 / (self.head_dim ** 0.5)
        else:
            self.beta = float(beta)

        self.initializer_range = float(initializer_range)
        self.dropout = nn.Dropout(dropout)

        self.Wq = nn.Parameter(
            torch.empty(nheads, embedding_dim, self.head_dim, device=device)
        )
        self.Wk = nn.Parameter(
            torch.empty(nheads, embedding_dim, self.head_dim, device=device)
        )

        with torch.no_grad():
            self.Wq.normal_(mean=0.0, std=self.initializer_range)
            self.Wk.normal_(mean=0.0, std=self.initializer_range)

    def _ensure_batch_dim(self, x):
        squeeze_back = False
        if x.dim() == 2:
            x = x.unsqueeze(0)
            squeeze_back = True
        return x, squeeze_back

    def _resolve_keep_mask(self, attention_mask=None):
        if attention_mask is not None:
            return attention_mask.to(torch.bool)
        return None

    def _no_self_mask(self, seq_len, device):
        return torch.eye(seq_len, dtype=torch.bool, device=device)

    def forward(
        self,
        x,
        attention_mask=None,
        allow_self=True,
    ):
        """
        Returns:
            tensor of shape (B, S, D) or (S, D)
        """
        x, squeeze_back = self._ensure_batch_dim(x)
        B, S, _ = x.shape

        keep = self._resolve_keep_mask(attention_mask)

        # (B, 1, S, D) @ (H, D, dk) -> (B, H, S, dk)
        Q = x.unsqueeze(1) @ self.Wq
        K = x.unsqueeze(1) @ self.Wk

        WqT = self.Wq.transpose(-1, -2)
        WkT = self.Wk.transpose(-1, -2)

        # The energy gradient splits into a query-side term (T1, uses V1) and
        # a key-side term (T2, uses V2).
        V1 = K @ WqT
        V2 = Q @ WkT

        # SDPA computes softmax(QK^T / sqrt(dk)). We want softmax(beta * QK^T),
        # so we rescale Q by beta * sqrt(dk) to match the logits used in
        # energy() and in the exact T2 path.
        dk = Q.shape[-1]
        q_scale = self.beta * (dk ** 0.5)
        Qs = Q * q_scale

        # Boolean SDPA mask: True = attend, False = masked out.
        sdpa_mask = None
        if keep is not None:
            sdpa_mask = keep.view(B, 1, 1, S)

        if not allow_self:
            no_self = ~self._no_self_mask(S, x.device)
            no_self = no_self.view(1, 1, S, S)
            sdpa_mask = no_self if sdpa_mask is None else (sdpa_mask & no_self)

        with sdpa_kernel(
            [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
        ):
            T1 = -F.scaled_dot_product_attention(
                Qs,
                K,
                V1,
                attn_mask=sdpa_mask,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=False,
            )

        logits = self.beta * (Q @ K.transpose(-2, -1))
        neg_inf = -float("inf")

        if keep is not None:
            logits = logits.masked_fill((~keep).view(B, 1, 1, S), neg_inf)

        if not allow_self:
            logits = logits.masked_fill(
                self._no_self_mask(S, x.device).view(1, 1, S, S), neg_inf
            )

        A = torch.softmax(logits.float(), dim=-1).to(logits.dtype)
        A = self.dropout(A)

        # Key-side term: token i also receives gradient through column i of A.
        T2 = -(A.transpose(-2, -1) @ V2)
        out = (T1 + T2).sum(dim=1)

        if squeeze_back:
            out = out.squeeze(0)
        return out

    def energy(
        self,
        x,
        attention_mask=None,
        allow_self=True,
    ):
        """
        Scalar energy summed over batch, heads and tokens.

        E = -(1 / beta) * sum logsumexp(beta * <Q, K> + mask)
        """
        x, _ = self._ensure_batch_dim(x)
        B, S, _ = x.shape

        keep = self._resolve_keep_mask(attention_mask)

        Q = x.unsqueeze(1) @ self.Wq
        K = x.unsqueeze(1) @ self.Wk

        logits = self.beta * (Q @ K.transpose(-2, -1))
        neg_inf = -float("inf")

        if keep is not None:
            logits = logits.masked_fill((~keep).view(B, 1, 1, S), neg_inf)

        if not allow_self:
            logits = logits.masked_fill(
                self._no_self_mask(S, x.device).view(1, 1, S, S), neg_inf
            )

        lse = torch.logsumexp(logits, dim=-1)
        return -(1.0 / self.beta) * lse.sum()

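# A minimal usage sketch (added illustration; shapes and mask are assumed).
# attention_mask is a boolean keep-mask over keys, and allow_self=False masks
# the diagonal so no token attends to itself:
#
#     mha = HopfieldMHA(embedding_dim=12, nheads=3)
#     x = torch.randn(2, 4, 12)
#     keep = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
#     upd = mha(x, attention_mask=keep, allow_self=False)   # (2, 4, 12)
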
class HopfieldLayer(nn.Module):
    """
    Full Hopfield layer: its update (and energy) is the sum of a multi-head
    attention term and a feed-forward memory term evaluated on the same input.
    """

    def __init__(
        self,
        embedding_dim,
        nheads,
        forward_memories,
        forward_activation="relu",
        beta=1.0,
        bias=True,
        device=None,
        dropout=0.0,
        initializer_range=0.002,
    ):
        super().__init__()

        if forward_activation == "relu":
            self.ffn = HopfieldReLU(
                embedding_dim=embedding_dim,
                nmemories=forward_memories,
                bias=bias,
                device=device,
                dropout=dropout,
                initializer_range=initializer_range,
            )
        elif forward_activation == "softmax":
            self.ffn = HopfieldSoftmax(
                embedding_dim=embedding_dim,
                nmemories=forward_memories,
                beta=beta,
                bias=bias,
                device=device,
                dropout=dropout,
                initializer_range=initializer_range,
            )
        else:
            raise ValueError(
                f"Unsupported forward_activation='{forward_activation}'. "
                "Expected one of: 'relu', 'softmax'."
            )

        self.mha = HopfieldMHA(
            embedding_dim=embedding_dim,
            nheads=nheads,
            beta=beta,
            device=device,
            dropout=dropout,
            initializer_range=initializer_range,
        )

    def energy(
        self,
        x,
        attention_mask=None,
        allow_self=True,
    ):
        return self.mha.energy(
            x,
            attention_mask=attention_mask,
            allow_self=allow_self,
        ) + self.ffn.energy(x)

    def forward(
        self,
        x,
        attention_mask=None,
        allow_self=True,
    ):
        ffn_output = self.ffn(x)
        mha_output = self.mha(
            x,
            attention_mask=attention_mask,
            allow_self=allow_self,
        )
        return ffn_output + mha_output

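# A minimal usage sketch (added illustration; hyperparameters are assumed). In
# an Energy-Transformer-style loop the layer output is the gradient of the
# energy at the normalized tokens g = LayerNorm(x), so token states can be
# relaxed by repeated descent steps:
#
#     layer = HopfieldLayer(embedding_dim=12, nheads=3, forward_memories=16)
#     norm = nn.LayerNorm(12)
#     x = torch.randn(2, 4, 12)
#     for _ in range(5):
#         x = x - 0.1 * layer(norm(x))   # step along -grad E evaluated at g
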
if __name__ == "__main__":

    def grad_energy(module, x, **kwargs):
        # torch.func.grad differentiates the pure function below, so x itself
        # does not need requires_grad.
        x = x.clone().detach()
        return torch.func.grad(lambda z: module.energy(z, **kwargs))(x)

    def check_module(name, module, x, atol=1e-5, rtol=1e-4, **kwargs):
        module.eval()
        with torch.no_grad():
            forward_out = module(x, **kwargs)
        grad_out = grad_energy(module, x, **kwargs)

        ok = torch.allclose(forward_out, grad_out, atol=atol, rtol=rtol)
        max_diff = (forward_out - grad_out).abs().max().item()

        print(f"\n=== {name} ===")
        print("forward shape :", tuple(forward_out.shape))
        print("grad shape    :", tuple(grad_out.shape))
        print("allclose      :", ok)
        print("max abs diff  :", max_diff)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(0)

    x = torch.randn(2, 4, 12, device=device)
    keep_mask = torch.tensor(
        [[1, 1, 1, 0],
         [1, 1, 0, 0]],
        dtype=torch.bool,
        device=device,
    )

    print("Testing elementary blocks against grad(E)...")

    hrelu = HopfieldReLU(12, 8, bias=False, device=device, dropout=0.0)
    check_module("HopfieldReLU", hrelu, x)

    hsoftmax = HopfieldSoftmax(12, 8, beta=1.0, bias=False, device=device, dropout=0.0)
    check_module("HopfieldSoftmax", hsoftmax, x)

    mha = HopfieldMHA(12, 3, beta=1.0, device=device, dropout=0.0)
    check_module("HopfieldMHA", mha, x, attention_mask=keep_mask)

    layer = HopfieldLayer(
        embedding_dim=12,
        nheads=3,
        forward_memories=16,
        forward_activation="relu",
        beta=1.0,
        bias=True,
        device=device,
        dropout=0.0,
    )
    check_module("HopfieldLayer", layer, x, attention_mask=keep_mask)

    print("\nTesting original-style external normalization...")
    norm = nn.LayerNorm(12).to(device).eval()
    g = norm(x)

    with torch.no_grad():
        update = layer(g, attention_mask=keep_mask)

    grad_g = torch.func.grad(lambda z: layer.energy(z, attention_mask=keep_mask))(
        g.clone().detach()
    )

    print("\n=== HopfieldLayer on normalized input g ===")
    print("allclose      :", torch.allclose(update, grad_g, atol=1e-5, rtol=1e-4))
    print("max abs diff  :", (update - grad_g).abs().max().item())