Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # /// script | |
| # dependencies = [ | |
| # "jax[cuda12]", | |
| # "equinox", | |
| # "optax", | |
| # "tqdm", | |
| # "jaxtyping", | |
| # ] | |
| # /// | |
| """Variational inference fine-tuning of the Ising Generator. | |
| Objective | |
| --------- | |
| Minimise the variational free energy | |
    F[q] = E_{s~q}[E(s)] − T · H[q]
         = T · E_{s~q}[ E(s)/T + log q(s) ]
which equals KL(q ‖ p*) up to the constant log Z, where
p*(s) ∝ exp(−E(s)/T) is the Ising Boltzmann distribution.
| As F decreases, the model q approaches the correct physics. | |
| Gradient estimator (REINFORCE / score-function) | |
| ------------------------------------------------ | |
    ∇_θ F = T · E_{s~q}[( E(s)/T + log q(s) − b ) · ∇_θ log q(s)]
where the baseline b = batch-mean reward is a control variate: it reduces
gradient variance while leaving the gradient estimate unbiased.
| Per training step | |
| ----------------- | |
| 1. Sample a batch of configs from the current model q_ΞΈ (slow on CPU) | |
| 2. Compute Ising energy E(s) for each sample (fast, no model) | |
| 3. Compute log q(s) via a teacher-forced forward pass (fast, one pass) | |
| 4. Assemble reward R = E/T + log q, subtract baseline, backprop (fast) | |
| Speed note | |
| ---------- | |
Step 1 dominates on CPU (~5 s / sample on a 32×32 lattice). On GPU it is
typically 10–100× faster and VI training with batch_size ≥ 32 is practical.
| The --checkpoint flag warm-starts from a CE-pretrained model, which dramati- | |
| cally reduces the number of VI steps needed to converge. | |
| Monitoring | |
| ---------- | |
| At each step the following quantities are logged: | |
    e   = ⟨E/N⟩          mean energy per spin (converges to data ≈ −1.45)
    h   = −⟨log q⟩/N     entropy per spin in nats (random init ≈ 0.693)
    f   = e − T·h        Helmholtz free energy per spin (we minimise this)
    |m| = ⟨|Σs / N|⟩     mean absolute magnetisation
| """ | |
| import argparse | |
| from pathlib import Path | |
| import equinox as eqx | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| import optax | |
| from tqdm.auto import tqdm | |
| from model import Generator, gen_config, snake_order | |
| from sample import load_checkpoint, sample_batch | |
| # --------------------------------------------------------------------------- | |
| # Physical constants | |
| # --------------------------------------------------------------------------- | |
| J = 1.0 | |
| T_C = 2.0 / np.log(1.0 + np.sqrt(2.0)) # exact 2D Ising T_c β 2.2692 | |
| # --------------------------------------------------------------------------- | |
| # Ising energy (pure JAX β no model parameters, no gradient needed) | |
| # --------------------------------------------------------------------------- | |
| def ising_energy_per_spin(token_ids: jax.Array) -> jax.Array: | |
| """Ising energy per spin from a snake-ordered token sequence {0, 1}. | |
| Hamiltonian: H = βJ Ξ£_{β¨ijβ©} s_i s_j with periodic boundary conditions. | |
| Returns a scalar. | |
| """ | |
| L = gen_config["lattice_size"] | |
| rows, cols = snake_order(L) # concrete at trace time | |
| spins = (token_ids * 2 - 1).astype(jnp.float32) # {0,1} β {β1,+1} | |
| grid = jnp.zeros((L, L)).at[jnp.asarray(rows), jnp.asarray(cols)].set(spins) | |
| right = jnp.roll(grid, -1, axis=1) | |
| down = jnp.roll(grid, -1, axis=0) | |
| return -J * (grid * right + grid * down).sum() / (L * L) | |
| # --------------------------------------------------------------------------- | |
| # VI loss (REINFORCE with per-batch baseline) | |
| # --------------------------------------------------------------------------- | |
| def compute_vi_loss( | |
| model, | |
| token_ids: jax.Array, | |
| T: float, | |
| ) -> tuple[jax.Array, dict]: | |
| """REINFORCE proxy loss for β_ΞΈ F[q]. | |
| The returned scalar is *not* the free energy β it is the REINFORCE | |
| surrogate whose gradient equals β_ΞΈ F/T. Use aux["f"] to track F. | |
| Parameters | |
| ---------- | |
| model : Generator | |
| token_ids : int array (batch, seq_len) samples drawn from q_ΞΈ | |
| T : target temperature | |
| Returns | |
| ------- | |
| (loss, aux), grads | |
| aux keys: e, h, f (per spin), |m|, reward_std (variance diagnostic) | |
| """ | |
| N = gen_config["lattice_size"] ** 2 | |
| # ββ log q(s) via teacher-forced forward pass βββββββββββββββββββββββββββββ | |
| # Disable dropout so we get the exact model log-probability. | |
| # in_axes=(0, None, None): vmap over batch; broadcast enable_dropout and key. | |
| logits = jax.vmap(model, in_axes=(0, None, None))( | |
| {"token_ids": token_ids}, False, None | |
| ) # (batch, seq_len, state_size) | |
| # Ξ£_t log p(s_t | s_{<t}) β summed over the sequence axis | |
| log_q = -optax.softmax_cross_entropy_with_integer_labels( | |
| logits[:, :-1, :], token_ids[:, 1:] | |
| ).sum(axis=-1) # (batch,) | |
| # ββ Ising energies ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # jax.lax.stop_gradient keeps energies out of the autodiff graph. | |
| energies = jax.lax.stop_gradient( | |
| jax.vmap(ising_energy_per_spin)(token_ids) # (batch,) | |
| ) | |
| # ββ REINFORCE reward R = E/T + log q(s) βββββββββββββββββββββββββββββββββ | |
| reward = jax.lax.stop_gradient(energies / T + log_q) | |
| baseline = reward.mean() | |
| # Whiten the reward: subtract mean AND divide by std. | |
| # This makes the gradient magnitude independent of reward scale and | |
| # prevents the high-variance collapse where all rewards are near-equal | |
| # (mode collapse β stdβ0 β gradientβ0 with mean-only baseline). | |
| reward_norm = (reward - baseline) / (reward.std() + 1e-8) | |
| # Proxy loss: β_ΞΈ loss = E_q[RΜ Β· β_ΞΈ log q] = β_ΞΈ (F/T) (up to scale) | |
| # Minimising this via gradient descent drives ΞΈ toward lower free energy. | |
| # NOTE: no negation β R is a cost to minimise, not a reward to maximise. | |
| loss = jnp.mean(jax.lax.stop_gradient(reward_norm) * log_q) | |
| # ββ Diagnostics (all stop-gradiented; no effect on training) βββββββββββββ | |
| e = energies.mean() # mean energy per spin | |
| h = -log_q.mean() / N # entropy per spin (nats) | |
| f = e - T * h # Helmholtz free energy per spin | |
| m = jnp.abs((token_ids * 2 - 1).astype(jnp.float32).mean(axis=-1)).mean() | |
| aux = { | |
| "e": e, | |
| "h": h, | |
| "f": f, | |
| "|m|": m, | |
| "reward_std": reward.std(), # REINFORCE variance diagnostic | |
| } | |
| return loss, aux | |
| # --------------------------------------------------------------------------- | |
| # Single training step (JIT-compiled; does NOT include sampling) | |
| # --------------------------------------------------------------------------- | |
| def vi_step( | |
| model, | |
| token_ids: jax.Array, | |
| opt_state, | |
| tx, | |
| T: float, | |
| ): | |
| """Compute VI loss + gradient and apply one optimiser update. | |
| Sampling is intentionally excluded so you can profile / replace it | |
| without re-compiling the gradient computation. | |
| """ | |
| (loss, aux), grads = compute_vi_loss(model, token_ids, T) | |
| updates, opt_state = tx.update(grads, opt_state, model) | |
| model = eqx.apply_updates(model, updates) | |
| return loss, aux, model, opt_state | |
| # --------------------------------------------------------------------------- | |
| # Sampling helper | |
| # --------------------------------------------------------------------------- | |
| _SAMPLE_BATCH = 4 # fixed call-site batch; changing triggers recompilation | |
| def draw_samples(model, n: int, key: jax.Array) -> jax.Array: | |
| """Sample n configurations from the model in fixed-size batches. | |
| Returns a jnp int32 array of shape (n, LΒ²). | |
| """ | |
| all_tokens = [] | |
| n_calls = -(-n // _SAMPLE_BATCH) # ceiling division | |
| for _ in range(n_calls): | |
| key, subkey = jax.random.split(key) | |
| all_tokens.append(np.asarray(sample_batch(model, _SAMPLE_BATCH, subkey))) | |
| return jnp.asarray(np.concatenate(all_tokens)[:n]) | |
| # --------------------------------------------------------------------------- | |
| # CLI | |
| # --------------------------------------------------------------------------- | |
| def parse_args(): | |
| p = argparse.ArgumentParser( | |
| description="Variational inference fine-tuning of an Ising Generator.", | |
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |
| ) | |
| p.add_argument("--checkpoint", type=Path, default=None, | |
| help="Warm-start from this .eqx file (strongly recommended).") | |
| p.add_argument("--output-checkpoint", type=Path, default=None, | |
| help="Save final model to this path.") | |
| p.add_argument("--num-steps", type=int, default=200) | |
| p.add_argument("--batch-size", type=int, default=16, | |
| help="Configurations sampled from q per gradient step.") | |
| p.add_argument("--learning-rate", type=float, default=1e-4) | |
| p.add_argument("--temperature", type=float, default=T_C, | |
| help=f"Target Boltzmann temperature (default: T_c β {T_C:.4f}).") | |
| p.add_argument("--log-every", type=int, default=1) | |
| p.add_argument("--seed", type=int, default=0) | |
| return p.parse_args() | |
| def main(): | |
| args = parse_args() | |
| T = args.temperature | |
| key = jax.random.PRNGKey(args.seed) | |
| # ββ Model βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if args.checkpoint is not None: | |
| print(f"Loading checkpoint from {args.checkpoint} β¦") | |
| model = load_checkpoint(args.checkpoint) | |
| else: | |
| print("Initialising model from scratch.") | |
| print(" Tip: use --checkpoint to warm-start from a CE-pretrained model.") | |
| key, model_key = jax.random.split(key) | |
| model = Generator(config=gen_config, key=model_key) | |
| # ββ Optimiser βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| tx = optax.chain( | |
| optax.clip_by_global_norm(1.0), | |
| optax.adam(learning_rate=args.learning_rate), | |
| ) | |
| tx = optax.masked(tx, jax.tree.map(eqx.is_inexact_array, model)) | |
| opt_state = tx.init(model) | |
| L = gen_config["lattice_size"] | |
| print(f"\nVI training | steps={args.num_steps} " | |
| f"batch={args.batch_size} T={T:.4f} lr={args.learning_rate} L={L}") | |
| print(" columns: e = β¨E/Nβ© h = ββ¨log qβ©/N " | |
| "f = eβTΒ·h (minimised) |m| = mean |magnetisation|\n") | |
| # ββ Training loop βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with tqdm(range(args.num_steps), unit="steps") as pbar: | |
| for step in pbar: | |
| # 1. Sample from current model (bottleneck on CPU) | |
| key, sample_key = jax.random.split(key) | |
| token_ids = draw_samples(model, args.batch_size, sample_key) | |
| # 2. VI gradient step (JIT-compiled teacher-forced forward pass) | |
| loss, aux, model, opt_state = vi_step( | |
| model, token_ids, opt_state, tx, T | |
| ) | |
| if step % args.log_every == 0: | |
| pbar.set_postfix( | |
| e = f"{float(aux['e']):.4f}", | |
| h = f"{float(aux['h']):.4f}", | |
| f = f"{float(aux['f']):.4f}", | |
| m = f"{float(aux['|m|']):.3f}", | |
| Rstd= f"{float(aux['reward_std']):.3f}", | |
| ) | |
| # ββ Save ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if args.output_checkpoint is not None: | |
| args.output_checkpoint.parent.mkdir(parents=True, exist_ok=True) | |
| eqx.tree_serialise_leaves(args.output_checkpoint, model) | |
| print(f"\nSaved checkpoint β {args.output_checkpoint}") | |
| if __name__ == "__main__": | |
| main() | |