"""Transformer model for autoregressive Ising spin generation.

Architecture: causal (GPT-style) transformer with per-site positional
embeddings in snake (boustrophedon) order.  The model is trained to maximise
p(s_0, s_1, ..., s_{N-1}) = ∏_t p(s_t | s_0, ..., s_{t-1}), where the spin
sites are visited in snake order over the L×L lattice.
"""

from collections.abc import Mapping

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Float, Int


def snake_order(size: int) -> tuple[np.ndarray, np.ndarray]:
    """Return (rows, cols) index arrays traversing an L×L grid in snake order.

    Even rows go left-to-right; odd rows go right-to-left.  The returned
    arrays have length size² and implement numpy advanced indexing:
        grid[rows, cols]  →  1-D sequence in snake order
        grid[rows, cols] = seq  →  scatter a sequence back to the grid
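
    Example (3×3 grid):
        >>> rows, cols = snake_order(3)
        >>> np.arange(9).reshape(3, 3)[rows, cols]
        array([0, 1, 2, 5, 4, 3, 6, 7, 8])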
    """
    if size <= 0:
        raise ValueError("size must be positive")
    rows, cols = [], []
    for row in range(size):
        columns = range(size) if row % 2 == 0 else range(size - 1, -1, -1)
        for col in columns:
            rows.append(row)
            cols.append(col)
    return np.array(rows), np.array(cols)


# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------

class EmbedderBlock(eqx.Module):
    """Spin-state + lattice-position embedder.

    Each position in the snake-order sequence gets three embeddings summed:
      • a learned spin-state embedding  (token  ∈ {0, 1})
      • a learned row-position embedding
      • a learned column-position embedding

    The row/column indices are derived from `snake_order` at trace time, so
    they fold to compile-time constants; no positional index arrays are
    stored as model parameters.
    """

    state_embedder: eqx.nn.Embedding
    row_embedder: eqx.nn.Embedding
    column_embedder: eqx.nn.Embedding
    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout
    lattice_size: int = eqx.field(static=True)

    def __init__(
        self,
        state_size: int,
        lattice_size: int,
        embedding_size: int,
        hidden_size: int,
        dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        state_key, row_key, col_key = jax.random.split(key, 3)
        self.state_embedder = eqx.nn.Embedding(
            num_embeddings=state_size, embedding_size=embedding_size, key=state_key
        )
        self.row_embedder = eqx.nn.Embedding(
            num_embeddings=lattice_size, embedding_size=embedding_size, key=row_key
        )
        self.column_embedder = eqx.nn.Embedding(
            num_embeddings=lattice_size, embedding_size=embedding_size, key=col_key
        )
        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)
        self.lattice_size = lattice_size

    def __call__(
        self,
        states: Int[Array, " seq_len"],
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        rows, cols = snake_order(self.lattice_size)   # concrete at trace time
        x_states = jax.vmap(self.state_embedder)(states)
        x_rows   = jax.vmap(self.row_embedder)(jnp.asarray(rows))
        x_cols   = jax.vmap(self.column_embedder)(jnp.asarray(cols))
        x = x_states + x_rows + x_cols
        x = jax.vmap(self.layernorm)(x)
        x = self.dropout(x, inference=not enable_dropout, key=key)
        return x


class FeedForwardBlock(eqx.Module):
    """Position-wise feed-forward block with residual connection."""

    mlp: eqx.nn.Linear
    output: eqx.nn.Linear
    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        mlp_key, out_key = jax.random.split(key)
        self.mlp    = eqx.nn.Linear(hidden_size, intermediate_size, key=mlp_key)
        self.output = eqx.nn.Linear(intermediate_size, hidden_size, key=out_key)
        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout   = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        inputs: Float[Array, " hidden_size"],
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, " hidden_size"]:
        x = jax.nn.gelu(self.mlp(inputs))
        x = self.output(x)
        x = self.dropout(x, inference=not enable_dropout, key=key)
        x = x + inputs
        x = self.layernorm(x)
        return x


class AttentionBlock(eqx.Module):
    """Multi-head self-attention with causal (lower-triangular) mask."""

    attention: eqx.nn.MultiheadAttention
    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout
    num_heads: int = eqx.field(static=True)

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        self.num_heads = num_heads
        self.attention = eqx.nn.MultiheadAttention(
            num_heads=num_heads,
            query_size=hidden_size,
            use_query_bias=True,
            use_key_bias=True,
            use_value_bias=True,
            use_output_bias=True,
            dropout_p=attention_dropout_rate,
            key=key,
        )
        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout   = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        inputs: Float[Array, "seq_len hidden_size"],
        mask: Int[Array, " seq_len"] | None,
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        attn_key, drop_key = (None, None) if key is None else jax.random.split(key)
        if mask is None:
            # No padding mask supplied: fall back to all-ones so the causal
            # structure is still enforced (this block is always causal).
            mask = jnp.ones(inputs.shape[0], dtype=jnp.int32)
        mask = self._causal_mask(mask)
        x = self.attention(
            query=inputs, key_=inputs, value=inputs,
            mask=mask, inference=not enable_dropout, key=attn_key,
        )
        x = self.dropout(x, inference=not enable_dropout, key=drop_key)
        x = x + inputs
        x = jax.vmap(self.layernorm)(x)
        return x

    def _causal_mask(
        self, mask: Int[Array, " seq_len"]
    ) -> Float[Array, "num_heads seq_len seq_len"]:
        """Lower-triangular mask combined with a padding mask."""
        n = mask.shape[0]
        pad   = jnp.multiply(mask[:, None], mask[None, :])          # [n, n]
        causal = jnp.tril(jnp.ones((n, n), dtype=mask.dtype))       # [n, n]
        m = jnp.multiply(pad, causal)                                # [n, n]
        m = jnp.broadcast_to(m[None], (self.num_heads, n, n))        # [H, n, n]
        # Equinox expects a boolean attention mask (True = may attend).
        return m.astype(bool)


class TransformerLayer(eqx.Module):
    """One transformer block: attention followed by feed-forward."""

    attention_block: AttentionBlock
    ff_block: FeedForwardBlock

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        attn_key, ff_key = jax.random.split(key)
        self.attention_block = AttentionBlock(
            hidden_size=hidden_size,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            key=attn_key,
        )
        self.ff_block = FeedForwardBlock(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            dropout_rate=dropout_rate,
            key=ff_key,
        )

    def __call__(
        self,
        inputs: Float[Array, "seq_len hidden_size"],
        mask: Int[Array, " seq_len"] | None = None,
        *,
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        attn_key, ff_key = (None, None) if key is None else jax.random.split(key)
        x = self.attention_block(inputs, mask, enable_dropout=enable_dropout, key=attn_key)
        n = x.shape[0]
        ff_keys = None if ff_key is None else jax.random.split(ff_key, n)
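        # One key per position; when ff_keys is None the vmap below still
        # works, because None is an empty pytree and its in_axes entry is
        # vacuous.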
        x = jax.vmap(self.ff_block, in_axes=(0, None, 0))(x, enable_dropout, ff_keys)
        return x


# ---------------------------------------------------------------------------
# Encoder and top-level Generator
# ---------------------------------------------------------------------------

class Encoder(eqx.Module):
    """Stack of transformer layers over a snake-ordered spin sequence."""

    embedder_block: EmbedderBlock
    layers: list[TransformerLayer]

    def __init__(
        self,
        state_size: int,
        lattice_size: int,
        embedding_size: int,
        hidden_size: int,
        intermediate_size: int,
        num_layers: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        emb_key, layer_key = jax.random.split(key)
        self.embedder_block = EmbedderBlock(
            state_size=state_size,
            lattice_size=lattice_size,
            embedding_size=embedding_size,
            hidden_size=hidden_size,
            dropout_rate=dropout_rate,
            key=emb_key,
        )
        layer_keys = jax.random.split(layer_key, num_layers)
        self.layers = [
            TransformerLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_heads=num_heads,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                key=lk,
            )
            for lk in layer_keys
        ]

    def __call__(
        self,
        states: Int[Array, " seq_len"],
        *,
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        emb_key, l_key = (None, None) if key is None else jax.random.split(key)
        x = self.embedder_block(states, enable_dropout=enable_dropout, key=emb_key)
        mask = jnp.ones_like(states, dtype=jnp.int32)  # no padding; causal only
        for layer in self.layers:
            cl_key, l_key = (None, None) if l_key is None else jax.random.split(l_key)
            x = layer(x, mask, enable_dropout=enable_dropout, key=cl_key)
        return x


class Generator(eqx.Module):
    """Autoregressive transformer generator for Ising spin configurations.

    Input:  token_ids — integer spin tokens {0=down, 1=up} in snake order.
    Output: logits    — shape (seq_len, state_size), where logits[t] is the
                        predicted distribution over the spin at position t+1
                        given positions 0..t.
    """

    encoder: Encoder
    lm_head: eqx.nn.Linear
    dropout: eqx.nn.Dropout

    def __init__(self, config: Mapping, key: jax.random.PRNGKey):
        enc_key, head_key = jax.random.split(key)
        self.encoder = Encoder(
            state_size=config["state_size"],
            lattice_size=config["lattice_size"],
            embedding_size=config["hidden_size"],
            hidden_size=config["hidden_size"],
            intermediate_size=config["intermediate_size"],
            num_layers=config["num_hidden_layers"],
            num_heads=config["num_attention_heads"],
            dropout_rate=config["hidden_dropout_prob"],
            attention_dropout_rate=config["attention_probs_dropout_prob"],
            key=enc_key,
        )
        self.lm_head = eqx.nn.Linear(
            in_features=config["hidden_size"],
            out_features=config["state_size"],
            key=head_key,
        )
        self.dropout = eqx.nn.Dropout(config["hidden_dropout_prob"])

    def __call__(
        self,
        inputs: dict[str, Int[Array, " seq_len"]],
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len state_size"]:
        e_key, d_key = (None, None) if key is None else jax.random.split(key)
        x = self.encoder(inputs["token_ids"], enable_dropout=enable_dropout, key=e_key)
        x = self.dropout(x, inference=not enable_dropout, key=d_key)
        return jax.vmap(self.lm_head)(x)
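

# A minimal next-token loss sketch (an assumption about intended use; the
# actual training code lives elsewhere).  Because logits[t] conditions on
# tokens 0..t, the target for logits[t] is token t+1: drop the last logit
# row and skip the first token.
#
#   def nll(model, token_ids):
#       logits = model({"token_ids": token_ids})     # (seq_len, state_size)
#       log_probs = jax.nn.log_softmax(logits[:-1])  # predicts tokens 1..N-1
#       return -jnp.take_along_axis(
#           log_probs, token_ids[1:, None], axis=1
#       ).sum()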


# ---------------------------------------------------------------------------
# Default configuration
# ---------------------------------------------------------------------------

gen_config = {
    "state_size": 2,            # spin tokens: 0 (↓) or 1 (↑)
    "lattice_size": 32,         # L×L lattice → L² = 1024 sequence length
    "hidden_size": 128,
    "num_hidden_layers": 2,
    "num_attention_heads": 2,
    "hidden_act": "gelu",
    "intermediate_size": 512,
    "hidden_dropout_prob": 0.1,
    "attention_probs_dropout_prob": 0.1,
}
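

if __name__ == "__main__":
    # Smoke test: a minimal sketch using the default `gen_config` above; it
    # builds the model and runs one forward pass on a random spin sequence.
    # (Autoregressive sampling would instead loop seq_len forward passes,
    # drawing each spin from softmax(logits[t]) before predicting t+1.)
    root_key = jax.random.PRNGKey(0)
    model_key, data_key = jax.random.split(root_key)
    model = Generator(gen_config, key=model_key)
    seq_len = gen_config["lattice_size"] ** 2
    token_ids = jax.random.bernoulli(data_key, shape=(seq_len,)).astype(jnp.int32)
    logits = model({"token_ids": token_ids})
    print(logits.shape)  # (1024, 2): one distribution over {down, up} per site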