"""Sample Ising configurations from a trained checkpoint and plot them.

Loads ``checkpoint.eqx``, draws ``N_SAMPLES`` configurations, prints each
sample's magnetization and the mean absolute magnetization, and saves a
grid of the sampled lattices to ``samples.png``.
"""
import matplotlib

matplotlib.use('Agg')  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import sys  # noqa: E402
import jax  # noqa: E402

sys.path.insert(0, '.')  # make the local `sample` module importable from the cwd
from sample import load_checkpoint, sample_batch, tokens_to_grids  # noqa: E402

N_SAMPLES = 16  # number of configurations to draw and plot
N_ROWS = 2      # rows in the output figure; columns are derived from N_SAMPLES
SEED = 42       # PRNG seed for reproducible sampling

model = load_checkpoint('checkpoint.eqx')
# Read the lattice size from the model so the plot labels stay correct for any checkpoint.
lattice_size = model.encoder.embedder_block.lattice_size
key = jax.random.PRNGKey(SEED)

print(f'Compiling and sampling {N_SAMPLES} configurations...')
tokens = sample_batch(model, N_SAMPLES, key)
tokens = np.asarray(tokens)  # pull the device array back to host NumPy
# NOTE(review): assumes tokens_to_grids returns shape (N, L, L) with spins in [-1, 1]
# (consistent with the mean over axes 1-2 and vmin/vmax below) — confirm in `sample`.
grids = tokens_to_grids(tokens, lattice_size)
mags = grids.mean(axis=(1, 2))  # per-sample magnetization
print(f'Magnetizations: {np.round(mags, 3)}')
print(f'Mean |m|: {np.abs(mags).mean():.4f}')

n_cols = -(-N_SAMPLES // N_ROWS)  # ceiling division so every sample gets a panel
# squeeze=False keeps `axes` 2-D even if N_ROWS or n_cols is 1.
fig, axes = plt.subplots(N_ROWS, n_cols, figsize=(14, 4), squeeze=False)
for i, ax in enumerate(axes.flat):
    ax.axis('off')
    if i >= N_SAMPLES:
        continue  # spare panel when N_SAMPLES is not a multiple of N_ROWS
    ax.imshow(grids[i], cmap='gray', vmin=-1, vmax=1, interpolation='nearest')
    ax.set_title(f'm={mags[i]:.2f}', fontsize=7)
# NOTE(review): "T=T_c" and "2 epochs" are not derivable from the checkpoint here;
# keep them in sync with the training run that produced it.
fig.suptitle(f'Sampled Ising configs (L={lattice_size}, T=T_c, 2 epochs)', fontsize=10)
plt.tight_layout()
plt.savefig('samples.png', dpi=150, bbox_inches='tight')
plt.close(fig)  # release the figure's memory once it has been written
print('Saved samples.png')