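# NOTE: the three lines below appear to be hosting-page header residue, kept
# as comments so the file stays valid Python.
#   bertran-yorro's picture
#   Initial upload: model, training scripts, Gradio app, data
#   5c85f22 verified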
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import sys, jax
sys.path.insert(0, '.')
from sample import load_checkpoint, sample_batch, tokens_to_grids
# Sample a batch of configurations from the checkpointed model, print their
# magnetizations, and save them as a grid of images in 'samples.png'.

# Subplot layout; the number of samples drawn is derived from it so the two
# can never disagree.
n_rows, n_cols = 2, 8
n_samples = n_rows * n_cols

model = load_checkpoint('checkpoint.eqx')
lattice_size = model.encoder.embedder_block.lattice_size
key = jax.random.PRNGKey(42)  # fixed seed so the figure is reproducible

print(f'Compiling and sampling {n_samples} configurations...')
tokens = np.asarray(sample_batch(model, n_samples, key))
grids = tokens_to_grids(tokens, lattice_size)

# Per-sample magnetization: mean over the two trailing (spatial) axes.
mags = grids.mean(axis=(1, 2))
print(f'Magnetizations: {np.round(mags, 3)}')
print(f'Mean |m|: {np.abs(mags).mean():.4f}')

fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 4))
for ax, grid, mag in zip(axes.flat, grids, mags):
    # Colour scale pinned to [-1, 1]; presumably spins are +/-1 -- confirm
    # against tokens_to_grids.
    ax.imshow(grid, cmap='gray', vmin=-1, vmax=1, interpolation='nearest')
    ax.set_title(f'm={mag:.2f}', fontsize=7)
    ax.axis('off')

# Report the lattice side actually plotted instead of a hard-coded 32.
# NOTE(review): 'T=T_c' and '2 epochs' still describe one specific checkpoint
# and are not derivable from anything visible here.
fig.suptitle(
    f'Sampled Ising configs (L={grids.shape[1]}, T=T_c, 2 epochs)',
    fontsize=10,
)
plt.tight_layout()
plt.savefig('samples.png', dpi=150, bbox_inches='tight')
plt.close(fig)  # release the figure once it has been written to disk
print('Saved samples.png')