"""Transformer model for autoregressive Ising spin generation.
Architecture: causal (GPT-style) transformer with per-site positional
embeddings in snake (boustrophedon) order. The model is trained to maximise
p(s_0, s_1, ..., s_{N-1}) = ∏_t p(s_t | s_0, ..., s_{t-1}), where the spin
sites are visited in snake order over the L×L lattice.
"""
from collections.abc import Mapping
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Bool, Float, Int
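# Module-level overview (for orientation): Generator.__call__ takes
# {"token_ids": int array of length lattice_size**2, in snake order}, embeds it
# with EmbedderBlock, runs the TransformerLayer stack inside Encoder, applies
# dropout, and projects each position to state_size logits via lm_head.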
def snake_order(size: int) -> tuple[np.ndarray, np.ndarray]:
"""Return (rows, cols) index arrays traversing an L×L grid in snake order.
Even rows go left-to-right; odd rows go right-to-left. The returned
    arrays have length size² and act as NumPy advanced indices:
grid[rows, cols] → 1-D sequence in snake order
grid[rows, cols] = seq → scatter a sequence back to the grid
"""
if size <= 0:
raise ValueError("size must be positive")
rows, cols = [], []
for row in range(size):
columns = range(size) if row % 2 == 0 else range(size - 1, -1, -1)
for col in columns:
rows.append(row)
cols.append(col)
return np.array(rows), np.array(cols)
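# Illustrative round trip (a quick sanity check, not used by the model):
#   >>> grid = np.arange(9).reshape(3, 3)
#   >>> r, c = snake_order(3)
#   >>> grid[r, c]
#   array([0, 1, 2, 5, 4, 3, 6, 7, 8])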
# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------
class EmbedderBlock(eqx.Module):
"""Spin-state + lattice-position embedder.
    Each position in the snake-order sequence receives the sum of three embeddings:
• a learned spin-state embedding (token ∈ {0, 1})
• a learned row-position embedding
• a learned column-position embedding
The row/column indices are derived from `snake_order` at trace time, so
    they fold to compile-time constants; no positional-index arrays are stored
    as model parameters.
"""
state_embedder: eqx.nn.Embedding
row_embedder: eqx.nn.Embedding
column_embedder: eqx.nn.Embedding
layernorm: eqx.nn.LayerNorm
dropout: eqx.nn.Dropout
lattice_size: int = eqx.field(static=True)
def __init__(
self,
state_size: int,
lattice_size: int,
embedding_size: int,
hidden_size: int,
dropout_rate: float,
key: jax.random.PRNGKey,
):
state_key, row_key, col_key = jax.random.split(key, 3)
self.state_embedder = eqx.nn.Embedding(
num_embeddings=state_size, embedding_size=embedding_size, key=state_key
)
self.row_embedder = eqx.nn.Embedding(
num_embeddings=lattice_size, embedding_size=embedding_size, key=row_key
)
self.column_embedder = eqx.nn.Embedding(
num_embeddings=lattice_size, embedding_size=embedding_size, key=col_key
)
self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
self.dropout = eqx.nn.Dropout(dropout_rate)
self.lattice_size = lattice_size
def __call__(
self,
states: Int[Array, " seq_len"],
enable_dropout: bool = False,
key: jax.Array | None = None,
) -> Float[Array, "seq_len hidden_size"]:
rows, cols = snake_order(self.lattice_size) # concrete at trace time
x_states = jax.vmap(self.state_embedder)(states)
x_rows = jax.vmap(self.row_embedder)(jnp.asarray(rows))
x_cols = jax.vmap(self.column_embedder)(jnp.asarray(cols))
x = x_states + x_rows + x_cols
x = jax.vmap(self.layernorm)(x)
x = self.dropout(x, inference=not enable_dropout, key=key)
return x
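# Shape sketch for EmbedderBlock with the default gen_config defined at the
# bottom of this file (lattice_size=32, embedding_size=hidden_size=128):
# states (1024,) → x_states, x_rows, x_cols each (1024, 128) → summed,
# layer-normed, dropped-out output (1024, 128).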
class FeedForwardBlock(eqx.Module):
"""Position-wise feed-forward block with residual connection."""
mlp: eqx.nn.Linear
output: eqx.nn.Linear
layernorm: eqx.nn.LayerNorm
dropout: eqx.nn.Dropout
def __init__(
self,
hidden_size: int,
intermediate_size: int,
dropout_rate: float,
key: jax.random.PRNGKey,
):
mlp_key, out_key = jax.random.split(key)
self.mlp = eqx.nn.Linear(hidden_size, intermediate_size, key=mlp_key)
self.output = eqx.nn.Linear(intermediate_size, hidden_size, key=out_key)
self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
self.dropout = eqx.nn.Dropout(dropout_rate)
def __call__(
self,
inputs: Float[Array, " hidden_size"],
enable_dropout: bool = False,
key: jax.Array | None = None,
) -> Float[Array, " hidden_size"]:
x = jax.nn.gelu(self.mlp(inputs))
x = self.output(x)
x = self.dropout(x, inference=not enable_dropout, key=key)
x = x + inputs
x = self.layernorm(x)
return x
class AttentionBlock(eqx.Module):
"""Multi-head self-attention with causal (lower-triangular) mask."""
attention: eqx.nn.MultiheadAttention
layernorm: eqx.nn.LayerNorm
dropout: eqx.nn.Dropout
num_heads: int = eqx.field(static=True)
def __init__(
self,
hidden_size: int,
num_heads: int,
dropout_rate: float,
attention_dropout_rate: float,
key: jax.random.PRNGKey,
):
self.num_heads = num_heads
self.attention = eqx.nn.MultiheadAttention(
num_heads=num_heads,
query_size=hidden_size,
use_query_bias=True,
use_key_bias=True,
use_value_bias=True,
use_output_bias=True,
dropout_p=attention_dropout_rate,
key=key,
)
self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
self.dropout = eqx.nn.Dropout(dropout_rate)
def __call__(
self,
inputs: Float[Array, "seq_len hidden_size"],
mask: Int[Array, " seq_len"] | None,
enable_dropout: bool = False,
        key: jax.Array | None = None,
) -> Float[Array, "seq_len hidden_size"]:
attn_key, drop_key = (None, None) if key is None else jax.random.split(key)
        if mask is None:
            # No padding mask supplied: still apply causal masking so that
            # position t cannot attend to positions t+1..n-1.
            mask = jnp.ones(inputs.shape[0], dtype=jnp.int32)
        mask = self._causal_mask(mask)
x = self.attention(
query=inputs, key_=inputs, value=inputs,
mask=mask, inference=not enable_dropout, key=attn_key,
)
x = self.dropout(x, inference=not enable_dropout, key=drop_key)
x = x + inputs
x = jax.vmap(self.layernorm)(x)
return x
def _causal_mask(
self, mask: Int[Array, " seq_len"]
    ) -> Bool[Array, "num_heads seq_len seq_len"]:
"""Lower-triangular mask combined with a padding mask."""
n = mask.shape[0]
pad = jnp.multiply(mask[:, None], mask[None, :]) # [n, n]
causal = jnp.tril(jnp.ones((n, n), dtype=mask.dtype)) # [n, n]
m = jnp.multiply(pad, causal) # [n, n]
m = jnp.broadcast_to(m[None], (self.num_heads, n, n)) # [H, n, n]
        return m.astype(bool)
class TransformerLayer(eqx.Module):
"""One transformer block: attention followed by feed-forward."""
attention_block: AttentionBlock
ff_block: FeedForwardBlock
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_heads: int,
dropout_rate: float,
attention_dropout_rate: float,
key: jax.random.PRNGKey,
):
attn_key, ff_key = jax.random.split(key)
self.attention_block = AttentionBlock(
hidden_size=hidden_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
attention_dropout_rate=attention_dropout_rate,
key=attn_key,
)
self.ff_block = FeedForwardBlock(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
dropout_rate=dropout_rate,
key=ff_key,
)
def __call__(
self,
inputs: Float[Array, "seq_len hidden_size"],
mask: Int[Array, " seq_len"] | None = None,
*,
enable_dropout: bool = False,
key: jax.Array | None = None,
) -> Float[Array, "seq_len hidden_size"]:
attn_key, ff_key = (None, None) if key is None else jax.random.split(key)
x = self.attention_block(inputs, mask, enable_dropout=enable_dropout, key=attn_key)
n = x.shape[0]
ff_keys = None if ff_key is None else jax.random.split(ff_key, n)
x = jax.vmap(self.ff_block, in_axes=(0, None, 0))(x, enable_dropout, ff_keys)
return x
# ---------------------------------------------------------------------------
# Encoder and top-level Generator
# ---------------------------------------------------------------------------
class Encoder(eqx.Module):
"""Stack of transformer layers over a snake-ordered spin sequence."""
embedder_block: EmbedderBlock
layers: list[TransformerLayer]
def __init__(
self,
state_size: int,
lattice_size: int,
embedding_size: int,
hidden_size: int,
intermediate_size: int,
num_layers: int,
num_heads: int,
dropout_rate: float,
attention_dropout_rate: float,
key: jax.random.PRNGKey,
):
emb_key, layer_key = jax.random.split(key)
self.embedder_block = EmbedderBlock(
state_size=state_size,
lattice_size=lattice_size,
embedding_size=embedding_size,
hidden_size=hidden_size,
dropout_rate=dropout_rate,
key=emb_key,
)
layer_keys = jax.random.split(layer_key, num_layers)
self.layers = [
TransformerLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
attention_dropout_rate=attention_dropout_rate,
key=lk,
)
for lk in layer_keys
]
def __call__(
self,
states: Int[Array, " seq_len"],
*,
enable_dropout: bool = False,
key: jax.Array | None = None,
) -> Float[Array, "seq_len hidden_size"]:
emb_key, l_key = (None, None) if key is None else jax.random.split(key)
x = self.embedder_block(states, enable_dropout=enable_dropout, key=emb_key)
mask = jnp.ones_like(states, dtype=jnp.int32) # no padding; causal only
for layer in self.layers:
cl_key, l_key = (None, None) if l_key is None else jax.random.split(l_key)
x = layer(x, mask, enable_dropout=enable_dropout, key=cl_key)
return x
class Generator(eqx.Module):
"""Autoregressive transformer generator for Ising spin configurations.
Input: token_ids — integer spin tokens {0=down, 1=up} in snake order.
    Output: logits of shape (seq_len, state_size), where logits[t] gives the
            unnormalised log-probabilities for the spin at position t+1,
            conditioned on the spins at positions 0..t (see the usage sketch
            at the bottom of this file).
"""
encoder: Encoder
lm_head: eqx.nn.Linear
dropout: eqx.nn.Dropout
def __init__(self, config: Mapping, key: jax.random.PRNGKey):
enc_key, head_key = jax.random.split(key)
self.encoder = Encoder(
state_size=config["state_size"],
lattice_size=config["lattice_size"],
embedding_size=config["hidden_size"],
hidden_size=config["hidden_size"],
intermediate_size=config["intermediate_size"],
num_layers=config["num_hidden_layers"],
num_heads=config["num_attention_heads"],
dropout_rate=config["hidden_dropout_prob"],
attention_dropout_rate=config["attention_probs_dropout_prob"],
key=enc_key,
)
self.lm_head = eqx.nn.Linear(
in_features=config["hidden_size"],
out_features=config["state_size"],
key=head_key,
)
self.dropout = eqx.nn.Dropout(config["hidden_dropout_prob"])
def __call__(
self,
inputs: dict[str, Int[Array, " seq_len"]],
enable_dropout: bool = False,
        key: jax.Array | None = None,
) -> Float[Array, "seq_len state_size"]:
e_key, d_key = (None, None) if key is None else jax.random.split(key)
x = self.encoder(inputs["token_ids"], enable_dropout=enable_dropout, key=e_key)
x = self.dropout(x, inference=not enable_dropout, key=d_key)
return jax.vmap(self.lm_head)(x)
# ---------------------------------------------------------------------------
# Default configuration
# ---------------------------------------------------------------------------
gen_config = {
"state_size": 2, # spin tokens: 0 (↓) or 1 (↑)
"lattice_size": 32, # L×L lattice → L² = 1024 sequence length
"hidden_size": 128,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"hidden_act": "gelu",
"intermediate_size": 512,
"hidden_dropout_prob": 0.1,
"attention_probs_dropout_prob": 0.1,
}