| """Transformer model for autoregressive Ising spin generation. | |
| Architecture: causal (GPT-style) transformer with per-site positional | |
| embeddings in snake (boustrophedon) order. The model is trained to maximise | |
| p(s_0, s_1, ..., s_{N-1}) = ∏_t p(s_t | s_0, ..., s_{t-1}), where the spin | |
| sites are visited in snake order over the L×L lattice. | |
| """ | |
from collections.abc import Mapping

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Bool, Float, Int


def snake_order(size: int) -> tuple[np.ndarray, np.ndarray]:
    """Return (rows, cols) index arrays traversing an L×L grid in snake order.

    Even rows go left-to-right; odd rows go right-to-left. The returned
    arrays have length size² and implement numpy advanced indexing:

        grid[rows, cols]        → 1-D sequence in snake order
        grid[rows, cols] = seq  → scatter a sequence back onto the grid
    """
    if size <= 0:
        raise ValueError("size must be positive")
    rows, cols = [], []
    for row in range(size):
        columns = range(size) if row % 2 == 0 else range(size - 1, -1, -1)
        for col in columns:
            rows.append(row)
            cols.append(col)
    return np.array(rows), np.array(cols)
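
# Illustrative example (not executed): for size = 3 the traversal visits
#
#     (0,0) (0,1) (0,2)     row 0, left-to-right
#     (1,2) (1,1) (1,0)     row 1, right-to-left
#     (2,0) (2,1) (2,2)     row 2, left-to-right
#
# so snake_order(3) returns
#     rows = [0 0 0 1 1 1 2 2 2]
#     cols = [0 1 2 2 1 0 0 1 2]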


# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------
class EmbedderBlock(eqx.Module):
    """Spin-state + lattice-position embedder.

    Each position in the snake-order sequence gets three embeddings, summed:

      • a learned spin-state embedding (token ∈ {0, 1})
      • a learned row-position embedding
      • a learned column-position embedding

    The row/column indices are derived from `snake_order` at trace time, so
    they fold to compile-time constants; no positional index arrays are
    stored as model parameters.
    """

    state_embedder: eqx.nn.Embedding
    row_embedder: eqx.nn.Embedding
    column_embedder: eqx.nn.Embedding
    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout
    lattice_size: int = eqx.field(static=True)

    def __init__(
        self,
        state_size: int,
        lattice_size: int,
        embedding_size: int,
        hidden_size: int,
        dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        state_key, row_key, col_key = jax.random.split(key, 3)
        self.state_embedder = eqx.nn.Embedding(
            num_embeddings=state_size, embedding_size=embedding_size, key=state_key
        )
        self.row_embedder = eqx.nn.Embedding(
            num_embeddings=lattice_size, embedding_size=embedding_size, key=row_key
        )
        self.column_embedder = eqx.nn.Embedding(
            num_embeddings=lattice_size, embedding_size=embedding_size, key=col_key
        )
        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)
        self.lattice_size = lattice_size

    def __call__(
        self,
        states: Int[Array, " seq_len"],
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        rows, cols = snake_order(self.lattice_size)  # concrete at trace time
        x_states = jax.vmap(self.state_embedder)(states)
        x_rows = jax.vmap(self.row_embedder)(jnp.asarray(rows))
        x_cols = jax.vmap(self.column_embedder)(jnp.asarray(cols))
        x = x_states + x_rows + x_cols
        x = jax.vmap(self.layernorm)(x)
        x = self.dropout(x, inference=not enable_dropout, key=key)
        return x


class FeedForwardBlock(eqx.Module):
    """Position-wise feed-forward block with residual connection."""

    mlp: eqx.nn.Linear
    output: eqx.nn.Linear
    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        mlp_key, out_key = jax.random.split(key)
        self.mlp = eqx.nn.Linear(hidden_size, intermediate_size, key=mlp_key)
        self.output = eqx.nn.Linear(intermediate_size, hidden_size, key=out_key)
        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        inputs: Float[Array, " hidden_size"],
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, " hidden_size"]:
        x = jax.nn.gelu(self.mlp(inputs))
        x = self.output(x)
        x = self.dropout(x, inference=not enable_dropout, key=key)
        x = x + inputs
        x = self.layernorm(x)
        return x


class AttentionBlock(eqx.Module):
    """Multi-head self-attention with causal (lower-triangular) mask."""

    attention: eqx.nn.MultiheadAttention
    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout
    num_heads: int = eqx.field(static=True)

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        self.num_heads = num_heads
        self.attention = eqx.nn.MultiheadAttention(
            num_heads=num_heads,
            query_size=hidden_size,
            use_query_bias=True,
            use_key_bias=True,
            use_value_bias=True,
            use_output_bias=True,
            dropout_p=attention_dropout_rate,
            key=key,
        )
        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        inputs: Float[Array, "seq_len hidden_size"],
        mask: Int[Array, " seq_len"] | None,
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        attn_key, drop_key = (None, None) if key is None else jax.random.split(key)
        if mask is not None:
            mask = self._causal_mask(mask)
        x = self.attention(
            query=inputs,
            key_=inputs,
            value=inputs,
            mask=mask,
            inference=not enable_dropout,
            key=attn_key,
        )
        x = self.dropout(x, inference=not enable_dropout, key=drop_key)
        x = x + inputs
        x = jax.vmap(self.layernorm)(x)
        return x

    def _causal_mask(
        self, mask: Int[Array, " seq_len"]
    ) -> Bool[Array, "num_heads seq_len seq_len"]:
        """Lower-triangular (causal) mask combined with a padding mask."""
        n = mask.shape[0]
        pad = jnp.multiply(mask[:, None], mask[None, :])       # [n, n]
        causal = jnp.tril(jnp.ones((n, n), dtype=mask.dtype))  # [n, n]
        m = jnp.multiply(pad, causal)                          # [n, n]
        m = jnp.broadcast_to(m[None], (self.num_heads, n, n))  # [H, n, n]
        # eqx.nn.MultiheadAttention expects a boolean mask (True = attend).
        return m.astype(bool)
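
    # Illustrative example (no padding, mask = [1, 1, 1], n = 3): each head's
    # combined mask is the lower-triangular matrix (shown as 0/1)
    #
    #     [[1, 0, 0],
    #      [1, 1, 0],
    #      [1, 1, 1]]
    #
    # so position t attends only to positions 0..t.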


class TransformerLayer(eqx.Module):
    """One transformer block: attention followed by feed-forward."""

    attention_block: AttentionBlock
    ff_block: FeedForwardBlock

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        attn_key, ff_key = jax.random.split(key)
        self.attention_block = AttentionBlock(
            hidden_size=hidden_size,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            key=attn_key,
        )
        self.ff_block = FeedForwardBlock(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            dropout_rate=dropout_rate,
            key=ff_key,
        )

    def __call__(
        self,
        inputs: Float[Array, "seq_len hidden_size"],
        mask: Int[Array, " seq_len"] | None = None,
        *,
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        attn_key, ff_key = (None, None) if key is None else jax.random.split(key)
        x = self.attention_block(inputs, mask, enable_dropout=enable_dropout, key=attn_key)
        # One dropout key per sequence position for the vmapped feed-forward block.
        n = x.shape[0]
        ff_keys = None if ff_key is None else jax.random.split(ff_key, n)
        x = jax.vmap(self.ff_block, in_axes=(0, None, 0))(x, enable_dropout, ff_keys)
        return x


# ---------------------------------------------------------------------------
# Encoder and top-level Generator
# ---------------------------------------------------------------------------
class Encoder(eqx.Module):
    """Stack of transformer layers over a snake-ordered spin sequence."""

    embedder_block: EmbedderBlock
    layers: list[TransformerLayer]

    def __init__(
        self,
        state_size: int,
        lattice_size: int,
        embedding_size: int,
        hidden_size: int,
        intermediate_size: int,
        num_layers: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jax.random.PRNGKey,
    ):
        emb_key, layer_key = jax.random.split(key)
        self.embedder_block = EmbedderBlock(
            state_size=state_size,
            lattice_size=lattice_size,
            embedding_size=embedding_size,
            hidden_size=hidden_size,
            dropout_rate=dropout_rate,
            key=emb_key,
        )
        layer_keys = jax.random.split(layer_key, num_layers)
        self.layers = [
            TransformerLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_heads=num_heads,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                key=lk,
            )
            for lk in layer_keys
        ]

    def __call__(
        self,
        states: Int[Array, " seq_len"],
        *,
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        emb_key, l_key = (None, None) if key is None else jax.random.split(key)
        x = self.embedder_block(states, enable_dropout=enable_dropout, key=emb_key)
        mask = jnp.ones_like(states, dtype=jnp.int32)  # no padding; causal only
        for layer in self.layers:
            cl_key, l_key = (None, None) if l_key is None else jax.random.split(l_key)
            x = layer(x, mask, enable_dropout=enable_dropout, key=cl_key)
        return x


class Generator(eqx.Module):
    """Autoregressive transformer generator for Ising spin configurations.

    Input:
        token_ids: integer spin tokens {0 = down, 1 = up} in snake order.
    Output:
        logits: shape (seq_len, state_size); logits[t] is the predicted
        distribution over the spin at position t + 1 given positions 0..t.
    """

    encoder: Encoder
    lm_head: eqx.nn.Linear
    dropout: eqx.nn.Dropout

    def __init__(self, config: Mapping, key: jax.random.PRNGKey):
        enc_key, head_key = jax.random.split(key)
        self.encoder = Encoder(
            state_size=config["state_size"],
            lattice_size=config["lattice_size"],
            embedding_size=config["hidden_size"],
            hidden_size=config["hidden_size"],
            intermediate_size=config["intermediate_size"],
            num_layers=config["num_hidden_layers"],
            num_heads=config["num_attention_heads"],
            dropout_rate=config["hidden_dropout_prob"],
            attention_dropout_rate=config["attention_probs_dropout_prob"],
            key=enc_key,
        )
        self.lm_head = eqx.nn.Linear(
            in_features=config["hidden_size"],
            out_features=config["state_size"],
            key=head_key,
        )
        self.dropout = eqx.nn.Dropout(config["hidden_dropout_prob"])

    def __call__(
        self,
        inputs: dict[str, Int[Array, " seq_len"]],
        enable_dropout: bool = False,
        key: jax.Array | None = None,
    ) -> Float[Array, "seq_len state_size"]:
        e_key, d_key = (None, None) if key is None else jax.random.split(key)
        x = self.encoder(inputs["token_ids"], enable_dropout=enable_dropout, key=e_key)
        x = self.dropout(x, inference=not enable_dropout, key=d_key)
        return jax.vmap(self.lm_head)(x)
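

# The module defines the model but no sampler; the following is a minimal,
# hedged sketch of autoregressive sampling consistent with the Generator
# docstring. `sample_configuration` is a new helper introduced here for
# illustration, not part of the original API.
def sample_configuration(
    model: Generator, key: jax.Array, lattice_size: int
) -> np.ndarray:
    """Illustrative sampling sketch (not used by the training code above).

    Assumptions, per the Generator docstring: logits[t] parameterises the spin
    at position t + 1, and the first spin is drawn uniformly. One full forward
    pass per site, so this is slow but simple; the causal mask guarantees that
    logits[t] ignores the yet-unfilled positions > t.
    """
    seq_len = lattice_size**2
    key, first_key = jax.random.split(key)
    tokens = jnp.zeros(seq_len, dtype=jnp.int32)
    tokens = tokens.at[0].set(jax.random.bernoulli(first_key).astype(jnp.int32))
    for t in range(seq_len - 1):
        key, step_key = jax.random.split(key)
        logits = model({"token_ids": tokens})  # (seq_len, state_size)
        next_spin = jax.random.categorical(step_key, logits[t])
        tokens = tokens.at[t + 1].set(next_spin.astype(jnp.int32))
    # Scatter the snake-ordered sequence back onto the L×L grid.
    rows, cols = snake_order(lattice_size)
    grid = np.zeros((lattice_size, lattice_size), dtype=np.int32)
    grid[rows, cols] = np.asarray(tokens)
    return grid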


# ---------------------------------------------------------------------------
# Default configuration
# ---------------------------------------------------------------------------
gen_config = {
    "state_size": 2,        # spin tokens: 0 (↓) or 1 (↑)
    "lattice_size": 32,     # L×L lattice → L² = 1024 sequence length
    "hidden_size": 128,
    "num_hidden_layers": 2,
    "num_attention_heads": 2,
    "hidden_act": "gelu",
    "intermediate_size": 512,
    "hidden_dropout_prob": 0.1,
    "attention_probs_dropout_prob": 0.1,
}
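

if __name__ == "__main__":
    # Minimal smoke-test sketch: build the model from gen_config (with a small
    # 4×4 lattice substituted so it runs quickly; this override is illustrative,
    # not part of the original configuration), run one forward pass on a random
    # spin sequence, and accumulate the autoregressive log-likelihood implied by
    # the Generator docstring (logits[t] scores the spin at position t + 1).
    demo_config = {**gen_config, "lattice_size": 4}
    model_key, data_key = jax.random.split(jax.random.PRNGKey(0))
    model = Generator(demo_config, key=model_key)

    seq_len = demo_config["lattice_size"] ** 2
    token_ids = jax.random.randint(data_key, (seq_len,), 0, demo_config["state_size"])

    logits = model({"token_ids": token_ids})  # (seq_len, state_size)
    log_probs = jax.nn.log_softmax(logits, axis=-1)

    # Align predictions with targets: logits[:-1] predict token_ids[1:].
    targets = token_ids[1:]
    per_site = jnp.take_along_axis(log_probs[:-1], targets[:, None], axis=-1)[:, 0]
    print("sequence log-probability:", float(per_site.sum()))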