Spaces:
Sleeping
Sleeping
# Sample a batch of Ising configurations from a trained checkpoint, print
# their magnetizations, and save them as a grid of images to samples.png.
#
# Run from the directory containing checkpoint.eqx and sample.py: both the
# checkpoint path and the sys.path entry below are relative to the CWD.
import sys

import jax
import matplotlib

matplotlib.use('Agg')  # non-interactive backend; must be set before pyplot is imported
import matplotlib.pyplot as plt  # noqa: E402
import numpy as np

sys.path.insert(0, '.')  # make the local `sample` module importable
from sample import load_checkpoint, sample_batch, tokens_to_grids  # noqa: E402

CHECKPOINT_PATH = 'checkpoint.eqx'
OUTPUT_PATH = 'samples.png'
SEED = 42
# Subplot grid; the number of samples is derived from it so the two
# cannot disagree.
N_ROWS, N_COLS = 2, 8
N_SAMPLES = N_ROWS * N_COLS

model = load_checkpoint(CHECKPOINT_PATH)
# Lattice size comes from the model, so the plot title reflects it instead
# of assuming a fixed L.
lattice_size = model.encoder.embedder_block.lattice_size
key = jax.random.PRNGKey(SEED)

print(f'Compiling and sampling {N_SAMPLES} configurations...')
tokens = np.asarray(sample_batch(model, N_SAMPLES, key))
grids = tokens_to_grids(tokens, lattice_size)
# Per-sample magnetization: mean over axes 1 and 2.
# NOTE(review): assumes grids is (N_SAMPLES, L, L) with spins in {-1, +1}
# (consistent with vmin/vmax below) — confirm against tokens_to_grids.
mags = grids.mean(axis=(1, 2))
print(f'Magnetizations: {np.round(mags, 3)}')
print(f'Mean |m|: {np.abs(mags).mean():.4f}')

# squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so `.flat`
# always works.
fig, axes = plt.subplots(N_ROWS, N_COLS, figsize=(14, 4), squeeze=False)
for ax, grid, mag in zip(axes.flat, grids, mags):
    ax.imshow(grid, cmap='gray', vmin=-1, vmax=1, interpolation='nearest')
    ax.set_title(f'm={mag:.2f}', fontsize=7)
    ax.axis('off')
# TODO: T and the epoch count are hard-coded labels, not read from the
# checkpoint; update them if the training setup changes.
fig.suptitle(
    f'Sampled Ising configs (L={lattice_size}, T=T_c, 2 epochs)', fontsize=10
)
fig.tight_layout()
fig.savefig(OUTPUT_PATH, dpi=150, bbox_inches='tight')
plt.close(fig)  # release the figure once it has been written
print(f'Saved {OUTPUT_PATH}')