Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # /// script | |
| # dependencies = [ | |
| # "jax[cuda12]", | |
| # "equinox", | |
| # "optax", | |
| # "einops", | |
| # "tqdm", | |
| # "jaxtyping", | |
| # ] | |
| # /// | |
| """Training script for the Ising spin Generator. | |
| Usage: | |
| python train.py [--epochs N] [--batch-size B] [--learning-rate LR] | |
| [--max-train-steps S] [--max-eval-batches E] [--seed SEED] | |
| [--data path/to/spins.npy] [--output-checkpoint model.eqx] | |
| """ | |
| import argparse | |
| import functools | |
| from pathlib import Path | |
| import einops | |
| import equinox as eqx | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| import optax | |
| from tqdm.auto import tqdm | |
| from model import Generator, gen_config, snake_order | |
| # --------------------------------------------------------------------------- | |
| # Data loading | |
| # --------------------------------------------------------------------------- | |
def load_ising_data(
    path: Path, train_frac: float = 0.9
) -> tuple[np.ndarray, np.ndarray]:
    """Load Ising configurations and convert them to snake-ordered token rows.

    The array stored at ``path`` is read with ``np.load``; spins are mapped
    ``-1 -> 0`` and ``+1 -> 1``; each lattice is then flattened into a single
    sequence by indexing it with the ``(rows, cols)`` pair returned by
    ``snake_order(L)``.

    Args:
        path: Location of the ``.npy`` file; the array must have shape
            ``(N, L, L)``.
        train_frac: Fraction of the samples, taken from the front of the
            array without shuffling, that forms the training split.

    Returns:
        ``(train_tokens, val_tokens)``, each ``(n, L²)`` int32.

    Raises:
        ValueError: If ``train_frac`` lies outside ``[0, 1]`` or the loaded
            array is not a stack of square lattices.
    """
    # Fail fast on a bad split fraction before touching the disk.
    if not 0.0 <= train_frac <= 1.0:
        raise ValueError(f"train_frac must be in [0, 1], got {train_frac}")

    spins = np.load(path)
    # snake_order() is built from a single side length, so a non-square or
    # wrong-rank array would be silently truncated or mis-indexed below.
    if spins.ndim != 3 or spins.shape[1] != spins.shape[2]:
        raise ValueError(
            f"expected spins of shape (N, L, L), got {spins.shape} from {path}"
        )

    lattice_size = spins.shape[1]
    # {-1, +1} -> {0, 1}; widen to int32 first so the +1 cannot overflow.
    tokens = (spins.astype(np.int32) + 1) // 2
    rows, cols = snake_order(lattice_size)
    tokens = tokens[:, rows, cols]  # (N, L, L) -> (N, len(rows))

    n_train = int(len(tokens) * train_frac)
    return tokens[:n_train], tokens[n_train:]
| # --------------------------------------------------------------------------- | |
| # Batch preparation | |
| # --------------------------------------------------------------------------- | |
def prepare_batch(batch: np.ndarray, num_devices: int) -> dict:
    """Split a flat batch across devices for pmap.

    The leading axis of ``batch`` (``(batch, seq)``) is factored into
    ``(devices, batch // devices)``, giving an array of shape
    ``(devices, batch // devices, seq)``. The result is returned in a dict
    under the ``"token_ids"`` key.
    """
    pattern = "(devices batch) seq -> devices batch seq"
    per_device = einops.rearrange(batch, pattern, devices=num_devices)
    return dict(token_ids=per_device)
| # --------------------------------------------------------------------------- | |
| # Training / eval steps | |
| # --------------------------------------------------------------------------- | |
def compute_loss(model, inputs, key):
    """Mean next-token cross-entropy of ``model`` over one batch.

    ``key`` is split into one PRNG key per example, and the model is vmapped
    over the batch axis of ``inputs`` and of those keys. The second
    positional argument is passed as ``True`` for every example (presumably
    the dropout switch -- confirm against ``Generator.__call__``).

    The logits at position ``t`` are scored against the token at ``t + 1``,
    so the last logit and the first label are dropped.
    """
    token_ids = inputs["token_ids"]
    per_example_keys = jax.random.split(key, token_ids.shape[0])

    batched_model = jax.vmap(model, in_axes=(0, None, 0))
    logits = batched_model(inputs, True, per_example_keys)

    token_nll = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits[:, :-1, :],
        labels=token_ids[:, 1:],
    )
    return jnp.mean(token_nll)
def make_step(model, inputs, opt_state, key, tx):
    """Run one optimisation step on a single device shard (called inside pmap).

    Args:
        model: The model to update.
        inputs: Per-device batch, a dict holding ``"token_ids"``.
        opt_state: Optimiser state matching ``tx``.
        key: PRNG key; it is split so that one half drives this step and the
            other half is handed back for the next step.
        tx: The optax gradient transformation.

    Returns:
        ``(loss, model, opt_state, new_key)`` where ``loss`` is this device's
        loss and ``model`` / ``opt_state`` are the updated values.
    """
    key, new_key = jax.random.split(key)
    # compute_loss returns only the scalar loss, so it has to be wrapped to
    # obtain the gradients with respect to the model's floating-point leaves.
    loss, grads = eqx.filter_value_and_grad(compute_loss)(model, inputs, key)
    # Average gradients over the pmapped axis so every replica applies the
    # same update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    updates, opt_state = tx.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state, new_key
def make_eval_step(model, inputs):
    """Mean next-token NLL (nats/token) for one device's batch; runs inside pmap.

    The model is called with ``enable_dropout=False`` and vmapped over the
    batch axis of ``inputs``. Logits at position ``t`` are compared with the
    token at ``t + 1``.
    """
    deterministic_model = functools.partial(model, enable_dropout=False)
    logits = jax.vmap(deterministic_model)(inputs)

    token_nll = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits[:, :-1, :],
        labels=inputs["token_ids"][:, 1:],
    )
    return jnp.mean(token_nll)


# Device-parallel version of the evaluation step.
p_make_eval_step = eqx.filter_pmap(make_eval_step)
| # --------------------------------------------------------------------------- | |
| # pmap helpers | |
| # --------------------------------------------------------------------------- | |
def replicate_for_pmap(value, devices):
    """Give every leaf of ``value`` a leading axis with one copy per device.

    Each leaf is converted to a JAX array, broadcast to shape
    ``(len(devices), *leaf.shape)`` and placed with a ``NamedSharding`` that
    partitions that new leading axis over a 1-D mesh named ``"devices"``.
    """
    num_copies = len(devices)
    device_mesh = jax.sharding.Mesh(np.asarray(devices), ("devices",))
    leading_axis_sharding = jax.sharding.NamedSharding(
        device_mesh, jax.sharding.PartitionSpec("devices")
    )

    def _stack_and_place(leaf):
        array = jnp.asarray(leaf)
        stacked = jnp.broadcast_to(array, (num_copies, *array.shape))
        return jax.device_put(stacked, leading_axis_sharding)

    return jax.tree.map(_stack_and_place, value)
def unreplicate_from_pmap(value):
    """Return ``value`` with every leaf replaced by its index-0 slice."""

    def _first_copy(leaf):
        return leaf[0]

    return jax.tree.map(_first_copy, value)
| # --------------------------------------------------------------------------- | |
| # CLI | |
| # --------------------------------------------------------------------------- | |
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the behaviour of a
            plain ``parse_args()`` call; passing a list allows the parser to
            be driven programmatically.

    Returns:
        An ``argparse.Namespace`` with ``epochs``, ``batch_size``,
        ``learning_rate``, ``max_train_steps``, ``max_eval_batches``,
        ``data``, ``output_checkpoint`` and ``seed``.
    """
    p = argparse.ArgumentParser(description="Train an Ising spin generator.")
    p.add_argument("--epochs", type=int, default=10)
    p.add_argument("--batch-size", type=int, default=32)
    p.add_argument("--learning-rate", type=float, default=1e-4)
    # None means "no limit" for both caps below.
    p.add_argument("--max-train-steps", type=int, default=None)
    p.add_argument("--max-eval-batches", type=int, default=None)
    # By default the data file is looked up next to this script.
    p.add_argument("--data", type=Path,
                   default=Path(__file__).parent / "spins.npy")
    p.add_argument("--output-checkpoint", type=Path, default=None)
    p.add_argument("--seed", type=int, default=5678)
    return p.parse_args(argv)
def main():
    """Train the generator, report validation NLL and optionally save it.

    Steps: parse the CLI, build the model and optimiser, replicate model,
    optimiser state and PRNG key across the local devices, run the training
    epochs (optionally capped by ``--max-train-steps``), evaluate on the
    validation split (optionally capped by ``--max-eval-batches``) and, if
    ``--output-checkpoint`` is given, serialise the un-replicated model.

    Raises:
        ValueError: If the batch size is not a multiple of the number of
            local devices.
    """
    args = parse_args()

    # pmap shards over the devices attached to this process, so the batch
    # split and the replication below must both use the *local* devices.
    devices = jax.local_devices()
    num_devices = len(devices)
    print(f"JAX devices: {jax.devices()}")
    # Explicit raise rather than assert: asserts are stripped under ``-O``.
    if args.batch_size % num_devices != 0:
        raise ValueError(
            "batch-size must be a multiple of the number of devices"
        )

    key = jax.random.PRNGKey(args.seed)
    model_key, train_key = jax.random.split(key)
    model = Generator(config=gen_config, key=model_key)

    train_tokens, val_tokens = load_ising_data(args.data)
    print(
        f"Train: {len(train_tokens):,} Val: {len(val_tokens):,} "
        f"Seq len: {train_tokens.shape[1]}"
    )

    tx = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate=args.learning_rate),
    )
    # Mask to float-only leaves so integer bookkeeping fields are excluded.
    tx = optax.masked(tx, jax.tree.map(eqx.is_inexact_array, model))
    opt_state = tx.init(model)

    p_make_step = eqx.filter_pmap(
        functools.partial(make_step, tx=tx), axis_name="devices"
    )

    opt_state = replicate_for_pmap(opt_state, devices)
    model = replicate_for_pmap(model, devices)
    train_key = replicate_for_pmap(train_key, devices)

    step_limit = args.max_train_steps
    global_step = 0
    for epoch in range(args.epochs):
        # Stop before starting an epoch once the step budget is used up.
        if step_limit is not None and global_step >= step_limit:
            break

        # Fresh, reproducible shuffle for every epoch.
        rng = np.random.default_rng(args.seed + epoch)
        shuffled = train_tokens[rng.permutation(len(train_tokens))]

        # Only full batches are used; a trailing partial batch is dropped.
        num_batches = len(shuffled) // args.batch_size
        if step_limit is not None:
            # Clamp so the inner loop never runs past the step budget.
            num_batches = min(num_batches, step_limit - global_step)

        with tqdm(range(num_batches), unit="steps",
                  desc=f"Epoch {epoch + 1}/{args.epochs}") as pbar:
            for step in pbar:
                batch = shuffled[step * args.batch_size : (step + 1) * args.batch_size]
                inputs = prepare_batch(batch, num_devices)
                loss, model, opt_state, train_key = p_make_step(
                    model, inputs, opt_state, train_key
                )
                global_step += 1
                # ``loss`` holds one value per device; report their mean so
                # the number is comparable with the validation NLL.
                pbar.set_postfix(loss=float(np.mean(loss)))

    # ---- validation ----
    num_val = len(val_tokens) // args.batch_size
    if args.max_eval_batches is not None:
        num_val = min(num_val, args.max_eval_batches)

    val_losses = []
    for step in tqdm(range(num_val), unit="steps", desc="Validation"):
        batch = val_tokens[step * args.batch_size : (step + 1) * args.batch_size]
        inputs = prepare_batch(batch, num_devices)
        val_losses.append(float(np.mean(p_make_eval_step(model, inputs))))

    if val_losses:
        print(f"Val NLL: {np.mean(val_losses):.4f} nats/token")
    else:
        # Avoid printing ``nan`` (and a NumPy warning) for an empty list.
        print("Val NLL: n/a (no full validation batch available)")

    if args.output_checkpoint is not None:
        args.output_checkpoint.parent.mkdir(parents=True, exist_ok=True)
        eqx.tree_serialise_leaves(
            args.output_checkpoint, unreplicate_from_pmap(model)
        )
        print(f"Saved checkpoint → {args.output_checkpoint}")


if __name__ == "__main__":
    main()