"""Transformer model for autoregressive Ising spin generation. Architecture: causal (GPT-style) transformer with per-site positional embeddings in snake (boustrophedon) order. The model is trained to maximise p(s_0, s_1, ..., s_{N-1}) = ∏_t p(s_t | s_0, ..., s_{t-1}), where the spin sites are visited in snake order over the L×L lattice. """ from collections.abc import Mapping import equinox as eqx import jax import jax.numpy as jnp import numpy as np from jaxtyping import Array, Float, Int def snake_order(size: int) -> tuple[np.ndarray, np.ndarray]: """Return (rows, cols) index arrays traversing an L×L grid in snake order. Even rows go left-to-right; odd rows go right-to-left. The returned arrays have length size² and implement numpy advanced indexing: grid[rows, cols] → 1-D sequence in snake order grid[rows, cols] = seq → scatter a sequence back to the grid """ if size <= 0: raise ValueError("size must be positive") rows, cols = [], [] for row in range(size): columns = range(size) if row % 2 == 0 else range(size - 1, -1, -1) for col in columns: rows.append(row) cols.append(col) return np.array(rows), np.array(cols) # --------------------------------------------------------------------------- # Building blocks # --------------------------------------------------------------------------- class EmbedderBlock(eqx.Module): """Spin-state + lattice-position embedder. Each position in the snake-order sequence gets three embeddings summed: • a learned spin-state embedding (token ∈ {0, 1}) • a learned row-position embedding • a learned column-position embedding The row/column indices are derived from `snake_order` at trace time, so they fold to compile-time constants — no array model-parameters are stored. """ state_embedder: eqx.nn.Embedding row_embedder: eqx.nn.Embedding column_embedder: eqx.nn.Embedding layernorm: eqx.nn.LayerNorm dropout: eqx.nn.Dropout lattice_size: int = eqx.field(static=True) def __init__( self, state_size: int, lattice_size: int, embedding_size: int, hidden_size: int, dropout_rate: float, key: jax.random.PRNGKey, ): state_key, row_key, col_key = jax.random.split(key, 3) self.state_embedder = eqx.nn.Embedding( num_embeddings=state_size, embedding_size=embedding_size, key=state_key ) self.row_embedder = eqx.nn.Embedding( num_embeddings=lattice_size, embedding_size=embedding_size, key=row_key ) self.column_embedder = eqx.nn.Embedding( num_embeddings=lattice_size, embedding_size=embedding_size, key=col_key ) self.layernorm = eqx.nn.LayerNorm(shape=hidden_size) self.dropout = eqx.nn.Dropout(dropout_rate) self.lattice_size = lattice_size def __call__( self, states: Int[Array, " seq_len"], enable_dropout: bool = False, key: jax.Array | None = None, ) -> Float[Array, "seq_len hidden_size"]: rows, cols = snake_order(self.lattice_size) # concrete at trace time x_states = jax.vmap(self.state_embedder)(states) x_rows = jax.vmap(self.row_embedder)(jnp.asarray(rows)) x_cols = jax.vmap(self.column_embedder)(jnp.asarray(cols)) x = x_states + x_rows + x_cols x = jax.vmap(self.layernorm)(x) x = self.dropout(x, inference=not enable_dropout, key=key) return x class FeedForwardBlock(eqx.Module): """Position-wise feed-forward block with residual connection.""" mlp: eqx.nn.Linear output: eqx.nn.Linear layernorm: eqx.nn.LayerNorm dropout: eqx.nn.Dropout def __init__( self, hidden_size: int, intermediate_size: int, dropout_rate: float, key: jax.random.PRNGKey, ): mlp_key, out_key = jax.random.split(key) self.mlp = eqx.nn.Linear(hidden_size, intermediate_size, key=mlp_key) self.output = 
# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------


class EmbedderBlock(eqx.Module):
    """Spin-state + lattice-position embedder.

    Each position in the snake-order sequence gets three embeddings summed:

      • a learned spin-state embedding (token ∈ {0, 1})
      • a learned row-position embedding
      • a learned column-position embedding

    The row/column indices are derived from `snake_order` at trace time, so
    they fold to compile-time constants; no positional-index arrays are stored
    as model parameters.

    Note: `embedding_size` must equal `hidden_size`, since the summed
    embeddings are normalised by a LayerNorm of shape `hidden_size`
    (the Generator passes `hidden_size` for both).
    """

    state_embedder: eqx.nn.Embedding
    row_embedder: eqx.nn.Embedding
    column_embedder: eqx.nn.Embedding
    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout
    lattice_size: int = eqx.field(static=True)

    def __init__(
        self,
        state_size: int,
        lattice_size: int,
        embedding_size: int,
        hidden_size: int,
        dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        state_key, row_key, col_key = jax.random.split(key, 3)
        self.state_embedder = eqx.nn.Embedding(
            num_embeddings=state_size, embedding_size=embedding_size, key=state_key
        )
        self.row_embedder = eqx.nn.Embedding(
            num_embeddings=lattice_size, embedding_size=embedding_size, key=row_key
        )
        self.column_embedder = eqx.nn.Embedding(
            num_embeddings=lattice_size, embedding_size=embedding_size, key=col_key
        )
        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)
        self.lattice_size = lattice_size

    def __call__(
        self,
        states: Int[Array, " seq_len"],
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        rows, cols = snake_order(self.lattice_size)  # concrete at trace time
        x_states = jax.vmap(self.state_embedder)(states)
        x_rows = jax.vmap(self.row_embedder)(jnp.asarray(rows))
        x_cols = jax.vmap(self.column_embedder)(jnp.asarray(cols))
        x = x_states + x_rows + x_cols
        x = jax.vmap(self.layernorm)(x)
        x = self.dropout(x, inference=not enable_dropout, key=key)
        return x


class FeedForwardBlock(eqx.Module):
    """Position-wise feed-forward block with residual connection."""

    mlp: eqx.nn.Linear
    output: eqx.nn.Linear
    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        mlp_key, out_key = jax.random.split(key)
        self.mlp = eqx.nn.Linear(hidden_size, intermediate_size, key=mlp_key)
        self.output = eqx.nn.Linear(intermediate_size, hidden_size, key=out_key)
        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        inputs: Float[Array, " hidden_size"],
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, " hidden_size"]:
        x = jax.nn.gelu(self.mlp(inputs))
        x = self.output(x)
        x = self.dropout(x, inference=not enable_dropout, key=key)
        x = x + inputs
        x = self.layernorm(x)
        return x


class AttentionBlock(eqx.Module):
    """Multi-head self-attention with causal (lower-triangular) mask."""

    attention: eqx.nn.MultiheadAttention
    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout
    num_heads: int = eqx.field(static=True)

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        self.num_heads = num_heads
        self.attention = eqx.nn.MultiheadAttention(
            num_heads=num_heads,
            query_size=hidden_size,
            use_query_bias=True,
            use_key_bias=True,
            use_value_bias=True,
            use_output_bias=True,
            dropout_p=attention_dropout_rate,
            key=key,
        )
        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        inputs: Float[Array, "seq_len hidden_size"],
        mask: Int[Array, " seq_len"] | None,
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        attn_key, drop_key = (None, None) if key is None else jax.random.split(key)
        if mask is not None:
            mask = self._causal_mask(mask)
        x = self.attention(
            query=inputs,
            key_=inputs,
            value=inputs,
            mask=mask,
            inference=not enable_dropout,
            key=attn_key,
        )
        x = self.dropout(x, inference=not enable_dropout, key=drop_key)
        x = x + inputs
        x = jax.vmap(self.layernorm)(x)
        return x

    def _causal_mask(
        self, mask: Int[Array, " seq_len"]
    ) -> Bool[Array, "num_heads seq_len seq_len"]:
        """Lower-triangular mask combined with a padding mask."""
        n = mask.shape[0]
        pad = jnp.multiply(mask[:, None], mask[None, :])        # [n, n]
        causal = jnp.tril(jnp.ones((n, n), dtype=mask.dtype))   # [n, n]
        m = jnp.multiply(pad, causal)                            # [n, n]
        m = jnp.broadcast_to(m[None], (self.num_heads, n, n))    # [H, n, n]
        return m.astype(jnp.bool_)


class TransformerLayer(eqx.Module):
    """One transformer block: attention followed by feed-forward."""

    attention_block: AttentionBlock
    ff_block: FeedForwardBlock

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        attn_key, ff_key = jax.random.split(key)
        self.attention_block = AttentionBlock(
            hidden_size=hidden_size,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            key=attn_key,
        )
        self.ff_block = FeedForwardBlock(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            dropout_rate=dropout_rate,
            key=ff_key,
        )

    def __call__(
        self,
        inputs: Float[Array, "seq_len hidden_size"],
        mask: Int[Array, " seq_len"] | None = None,
        *,
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        attn_key, ff_key = (None, None) if key is None else jax.random.split(key)
        x = self.attention_block(inputs, mask, enable_dropout=enable_dropout, key=attn_key)
        n = x.shape[0]
        ff_keys = None if ff_key is None else jax.random.split(ff_key, n)
        x = jax.vmap(self.ff_block, in_axes=(0, None, 0))(x, enable_dropout, ff_keys)
        return x
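
# Illustrative example (not executed on import; `layer` and `h` are placeholder
# names): a TransformerLayer is applied to a whole snake-order sequence at
# once, with a per-position padding mask of ones meaning "no padding":
#
#     layer = TransformerLayer(
#         hidden_size=128, intermediate_size=512, num_heads=2,
#         dropout_rate=0.0, attention_dropout_rate=0.0,
#         key=jax.random.PRNGKey(0),
#     )
#     h = layer(jnp.zeros((16, 128)), jnp.ones(16, dtype=jnp.int32))  # (16, 128)
#
# Internally, the padding mask is combined with a lower-triangular causal
# mask, e.g. for seq_len = 4 (per head):
#
#     [[1 0 0 0]
#      [1 1 0 0]
#      [1 1 1 0]
#      [1 1 1 1]]
#
# so position t attends only to positions 0..t, which is what makes the
# factorisation p(s_t | s_0, ..., s_{t-1}) autoregressive.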
# ---------------------------------------------------------------------------
# Encoder and top-level Generator
# ---------------------------------------------------------------------------


class Encoder(eqx.Module):
    """Stack of transformer layers over a snake-ordered spin sequence."""

    embedder_block: EmbedderBlock
    layers: list[TransformerLayer]

    def __init__(
        self,
        state_size: int,
        lattice_size: int,
        embedding_size: int,
        hidden_size: int,
        intermediate_size: int,
        num_layers: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        emb_key, layer_key = jax.random.split(key)
        self.embedder_block = EmbedderBlock(
            state_size=state_size,
            lattice_size=lattice_size,
            embedding_size=embedding_size,
            hidden_size=hidden_size,
            dropout_rate=dropout_rate,
            key=emb_key,
        )
        layer_keys = jax.random.split(layer_key, num_layers)
        self.layers = [
            TransformerLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_heads=num_heads,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                key=lk,
            )
            for lk in layer_keys
        ]

    def __call__(
        self,
        states: Int[Array, " seq_len"],
        *,
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        emb_key, l_key = (None, None) if key is None else jax.random.split(key)
        x = self.embedder_block(states, enable_dropout=enable_dropout, key=emb_key)
        mask = jnp.ones_like(states, dtype=jnp.int32)  # no padding; causal only
        for layer in self.layers:
            cl_key, l_key = (None, None) if l_key is None else jax.random.split(l_key)
            x = layer(x, mask, enable_dropout=enable_dropout, key=cl_key)
        return x


class Generator(eqx.Module):
    """Autoregressive transformer generator for Ising spin configurations.

    Input:  token_ids — integer spin tokens {0 = down, 1 = up} in snake order.
    Output: logits — shape (seq_len, state_size), where logits[t] gives the
            unnormalised distribution over the spin at position t+1,
            conditioned on the spins at positions 0..t.
    """

    encoder: Encoder
    lm_head: eqx.nn.Linear
    dropout: eqx.nn.Dropout

    def __init__(self, config: Mapping, key: jax.random.PRNGKey):
        enc_key, head_key = jax.random.split(key)
        self.encoder = Encoder(
            state_size=config["state_size"],
            lattice_size=config["lattice_size"],
            embedding_size=config["hidden_size"],
            hidden_size=config["hidden_size"],
            intermediate_size=config["intermediate_size"],
            num_layers=config["num_hidden_layers"],
            num_heads=config["num_attention_heads"],
            dropout_rate=config["hidden_dropout_prob"],
            attention_dropout_rate=config["attention_probs_dropout_prob"],
            key=enc_key,
        )
        self.lm_head = eqx.nn.Linear(
            in_features=config["hidden_size"],
            out_features=config["state_size"],
            key=head_key,
        )
        self.dropout = eqx.nn.Dropout(config["hidden_dropout_prob"])

    def __call__(
        self,
        inputs: dict[str, Int[Array, " seq_len"]],
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len state_size"]:
        e_key, d_key = (None, None) if key is None else jax.random.split(key)
        x = self.encoder(inputs["token_ids"], enable_dropout=enable_dropout, key=e_key)
        x = self.dropout(x, inference=not enable_dropout, key=d_key)
        return jax.vmap(self.lm_head)(x)


# ---------------------------------------------------------------------------
# Default configuration
# ---------------------------------------------------------------------------

gen_config = {
    "state_size": 2,        # spin tokens: 0 (↓) or 1 (↑)
    "lattice_size": 32,     # L×L lattice → L² = 1024 sequence length
    "hidden_size": 128,
    "num_hidden_layers": 2,
    "num_attention_heads": 2,
    "hidden_act": "gelu",   # informational only; FeedForwardBlock hard-codes GELU
    "intermediate_size": 512,
    "hidden_dropout_prob": 0.1,
    "attention_probs_dropout_prob": 0.1,
}
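

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch only; the actual training and
# sampling code lives elsewhere and may differ). It assumes a smaller 4×4
# lattice purely so the example runs quickly, and scores a random spin
# configuration under the model's autoregressive factorisation.
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    demo_config = dict(gen_config, lattice_size=4)  # 4×4 lattice → seq_len 16
    model_key, data_key = jax.random.split(jax.random.PRNGKey(0))

    model = Generator(demo_config, model_key)

    seq_len = demo_config["lattice_size"] ** 2
    tokens = jax.random.bernoulli(data_key, 0.5, (seq_len,)).astype(jnp.int32)

    # Forward pass in inference mode (no dropout): one row of logits per site.
    logits = model({"token_ids": tokens})            # (seq_len, state_size)
    log_probs = jax.nn.log_softmax(logits, axis=-1)

    # logits[t] predicts the spin at position t+1, so score tokens[1:] with
    # log_probs[:-1]; the very first spin is not modelled in this sketch.
    picked = jnp.take_along_axis(log_probs[:-1], tokens[1:, None], axis=-1)
    print("logits shape:", logits.shape)
    print("log p(s_1, ..., s_{N-1} | s_0):", float(picked.sum()))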