#!/usr/bin/env python
# /// script
# dependencies = [
#   "jax[cuda12]",
#   "equinox",
#   "numpy",
#   "scipy",
#   "jaxtyping",
#   "tqdm",
# ]
# ///
"""Evaluate a trained Generator against held-out test samples.

For each configuration we compute an 11-dimensional feature vector of
physical observables. The Mahalanobis distance between the real and
generated feature distributions gives a single scalar measure of model
quality.

Per-sample feature vector
-------------------------
  m, m^2, |m|      magnetisation and its moments
  e, e^2           nearest-neighbour energy per spin (periodic BC)
  C(r)             connected two-point correlation at r = 1, 2, 4, 8
  s_mean/N         mean cluster size (4-connected, open BC)
  s_max/N          largest cluster size

Ensemble statistics (printed for reference, not part of Mahalanobis)
----------------------------------------------------------------------
  chi = N · Var(m) / T               magnetic susceptibility
  C_v = N · Var(e) / T²              specific heat
  U4  = 1 − <m^4> / (3 <m^2>^2)      Binder cumulant
        → 2/3 in the ordered phase
        → 0   in the disordered phase
        ≈ 0.61 at T_c for the 2D Ising model
          (square lattice, periodic BC, L → ∞)

Distance
--------
  D = sqrt( Δμ^T Σ_real^{-1} Δμ )

where Δμ = μ_gen − μ_real and Σ_real is the sample covariance of the real
test features. Per-feature z-scores Δμ_i / σ_real_i are also reported so you
can see which observables deviate most.
"""

import argparse
from pathlib import Path

import numpy as np
import scipy.ndimage
import jax
from tqdm.auto import tqdm

from model import gen_config
from sample import load_checkpoint, sample_batch, tokens_to_grids
from train import load_ising_data

# ---------------------------------------------------------------------------
# Physical constants
# ---------------------------------------------------------------------------

J = 1.0
T_C = 2.0 / np.log(1.0 + np.sqrt(2.0))  # exact: 2J / ln(1+√2) ≈ 2.2692

FEATURE_NAMES = [
    "m", "m^2", "|m|",
    "e", "e^2",
    "C(r=1)", "C(r=2)", "C(r=4)", "C(r=8)",
    "s_mean/N", "s_max/N",
]


# ---------------------------------------------------------------------------
# Per-sample observables
# ---------------------------------------------------------------------------

def energy_per_spin(grid: np.ndarray) -> float:
    """Nearest-neighbour energy density with periodic boundary conditions.

        E/N = −J/N · Σ_{⟨ij⟩} s_i s_j

    Each bond is counted once via right- and down-shifts.
    """
    right = np.roll(grid, -1, axis=1)
    down = np.roll(grid, -1, axis=0)
    return float(-J * (grid * right + grid * down).sum() / grid.size)


def connected_correlations(
    grid: np.ndarray,
    distances: tuple[int, ...] = (1, 2, 4, 8),
) -> np.ndarray:
    """Isotropic connected two-point function of a single configuration.

        C(r) = ½[⟨s_i s_{i+r·x̂}⟩ + ⟨s_i s_{i+r·ŷ}⟩] − m²

    ⟨·⟩ averages over all origin sites i, displacements wrap around
    (periodic boundary conditions), and m is this configuration's mean spin.
    Returns one value per entry of ``distances``.
    """
    m = float(grid.mean())
    corr = []
    for r in distances:
        cx = float((grid * np.roll(grid, r, axis=1)).mean())
        cy = float((grid * np.roll(grid, r, axis=0)).mean())
        corr.append((cx + cy) / 2.0 - m ** 2)
    return np.array(corr, dtype=np.float64)


def cluster_stats(grid: np.ndarray) -> tuple[float, float]:
    """Mean and maximum cluster size for both spin species.

    Uses 4-connectivity (no diagonals) and open boundary conditions.
    Returns sizes normalised by the total number of spins so the result is
    independent of lattice size.

    Note: open BC means edge-spanning clusters are split at the boundary;
    this is applied consistently to both real and generated samples so
    systematic bias cancels in the Mahalanobis comparison.
    """
    N = grid.size
    all_sizes: list[np.ndarray] = []
    for spin in (1, -1):
        # scipy.ndimage.label's default structuring element is the
        # 4-connected cross.
        labeled, n_labels = scipy.ndimage.label(grid == spin)
        if n_labels > 0:
            # bincount index 0 is background; skip it
            all_sizes.append(np.bincount(labeled.ravel())[1:])
    if not all_sizes:
        return 0.0, 0.0
    sizes = np.concatenate(all_sizes).astype(np.float64)
    return float(sizes.mean()) / N, float(sizes.max()) / N


def compute_features(grid: np.ndarray) -> np.ndarray:
    """Return the 11-D feature vector for a single ±1 grid of shape (L, L).

    The order of the entries matches ``FEATURE_NAMES``.
    """
    m = float(grid.mean())
    e = energy_per_spin(grid)
    cr = connected_correlations(grid)
    s_mean, s_max = cluster_stats(grid)
    return np.array(
        [m, m ** 2, abs(m), e, e ** 2, *cr, s_mean, s_max],
        dtype=np.float64,
    )


def compute_feature_matrix(grids: np.ndarray, desc: str = "features") -> np.ndarray:
    """Compute the (N, 11) feature matrix for a batch of grids."""
    return np.stack(
        [compute_features(grids[i]) for i in tqdm(range(len(grids)), desc=desc, leave=False)]
    )


# ---------------------------------------------------------------------------
# Ensemble statistics
# ---------------------------------------------------------------------------

def ensemble_stats(X: np.ndarray, T: float = T_C) -> dict[str, float]:
    """Derive thermodynamic ensemble statistics from a feature matrix.

    Arguments
    ---------
    X : (N, 11) feature matrix from ``compute_feature_matrix``.
    T : temperature used for χ and C_v normalisation.

    Returns
    -------
    dict with keys "<|m|>", "chi", "C_v" and "U4". U4 is NaN when <m^2> is 0.
    The number of lattice sites is taken from ``gen_config["lattice_size"]``.
    """
    L = gen_config["lattice_size"]
    N = L * L

    m = X[:, FEATURE_NAMES.index("m")]
    m2 = X[:, FEATURE_NAMES.index("m^2")]
    m4 = m ** 4
    e = X[:, FEATURE_NAMES.index("e")]

    chi = N * float(m.var()) / T
    Cv = N * float(e.var()) / T ** 2
    binder = float(1.0 - m4.mean() / (3.0 * m2.mean() ** 2)) if m2.mean() > 0 else float("nan")

    return {
        "<|m|>": float(np.abs(m).mean()),
        "chi": chi,
        "C_v": Cv,
        "U4": binder,
    }


# ---------------------------------------------------------------------------
# Mahalanobis distance
# ---------------------------------------------------------------------------

def mahalanobis_distance(
    X_ref: np.ndarray,
    X_query: np.ndarray,
    reg: float = 1e-6,
) -> tuple[float, np.ndarray]:
    """Mahalanobis distance of the query-mean from the reference distribution.

        D = sqrt( Δμ^T Σ_ref^{-1} Δμ )

    Also returns per-feature z-scores z_i = Δμ_i / σ_ref_i, where
    σ_ref_i = sqrt(Σ_ref[i,i]) (regularised, ddof=1). |z_i| > 1 indicates a
    feature whose mean differs by more than one real-sample standard
    deviation.

    Parameters
    ----------
    X_ref   : (N, d) real / reference feature matrix
    X_query : (M, d) generated / query feature matrix
    reg     : diagonal regularisation added to Σ_ref before solving

    Raises
    ------
    ValueError if X_ref has fewer than two rows (no covariance estimate).
    """
    if X_ref.shape[0] < 2:
        raise ValueError(
            f"need at least 2 reference samples to estimate a covariance, got {X_ref.shape[0]}"
        )
    mu_ref = X_ref.mean(axis=0)
    mu_query = X_query.mean(axis=0)
    cov = np.cov(X_ref, rowvar=False) + reg * np.eye(X_ref.shape[1])

    delta = mu_query - mu_ref
    # Solve Σ x = Δμ rather than forming Σ^{-1} explicitly: several features
    # are strongly collinear (m^2 vs |m|, e vs e^2), so Σ is ill-conditioned.
    quad = float(delta @ np.linalg.solve(cov, delta))
    D = float(np.sqrt(max(0.0, quad)))
    z_scores = delta / np.sqrt(np.diag(cov))
    return D, z_scores


# ---------------------------------------------------------------------------
# Reporting
# ---------------------------------------------------------------------------

def print_feature_table(X_real: np.ndarray, X_gen: np.ndarray) -> None:
    """Print per-feature mean/std of real and generated samples plus z-scores.

    Standard deviations are sample std-devs (ddof=1), matching the covariance
    used by ``mahalanobis_distance`` so both reported z-scores agree.
    Rows with |z| > 1 are flagged with '<'.
    """
    mu_r = X_real.mean(axis=0)
    sd_r = X_real.std(axis=0, ddof=1)
    mu_g = X_gen.mean(axis=0)
    sd_g = X_gen.std(axis=0, ddof=1)

    col = 13
    hdr = (f"  {'Feature':<11}  {'Real mean':>{col}}  {'Real std':>{col}}"
           f"  {'Gen mean':>{col}}  {'Gen std':>{col}}  {'z-score':>8}")
    print(hdr)
    print("  " + "─" * (len(hdr) - 2))
    for name, mr, sr, mg, sg in zip(FEATURE_NAMES, mu_r, sd_r, mu_g, sd_g):
        z = (mg - mr) / (sr + 1e-12)
        flag = " <" if abs(z) > 1.0 else ""
        print(f"  {name:<11}  {mr:>{col}.4f}  {sr:>{col}.4f}"
              f"  {mg:>{col}.4f}  {sg:>{col}.4f}  {z:>+8.3f}{flag}")
    print()


def print_ensemble_table(stats_real: dict, stats_gen: dict) -> None:
    """Print the ``ensemble_stats`` dicts for real and generated samples side by side."""
    labels = {
        "<|m|>": "mean |m|",
        "chi": "chi (susceptibility)",
        "C_v": "C_v (specific heat)",
        "U4": "U4 (Binder cumulant)",
    }
    print(f"  {'Observable':<26}  {'Real':>10}  {'Generated':>10}")
    print("  " + "─" * 50)
    for key, label in labels.items():
        r = stats_real[key]
        g = stats_gen[key]
        print(f"  {label:<26}  {r:>10.4f}  {g:>10.4f}")
    print()


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

_SAMPLE_BATCH = 4  # fixed vmapped batch; changing triggers recompilation


def generate_grids(model, n: int, key: jax.Array, L: int) -> np.ndarray:
    """Sample n grids in batches of _SAMPLE_BATCH with a progress bar.

    Every call to ``sample_batch`` uses the same batch size, so only one JIT
    compilation happens regardless of n. When n is not a multiple of
    _SAMPLE_BATCH the last batch is still sampled at full size and the
    surplus samples are discarded.
    """
    batches = []
    n_full, remainder = divmod(n, _SAMPLE_BATCH)
    n_batches = n_full + (1 if remainder else 0)
    with tqdm(total=n, unit="samples", desc="Sampling") as pbar:
        for i in range(n_batches):
            key, subkey = jax.random.split(key)
            tokens = np.asarray(sample_batch(model, _SAMPLE_BATCH, subkey))
            batches.append(tokens)
            pbar.update(min(_SAMPLE_BATCH, n - i * _SAMPLE_BATCH))
    return tokens_to_grids(np.concatenate(batches)[:n], L)


def load_test_grids(
    test_data: Path | None,
    data: Path,
    n: int,
    L: int,
    rng: np.random.Generator,
) -> np.ndarray:
    """Load real test grids, preferring a dedicated test file over the val split.

    Parameters
    ----------
    test_data : optional path to a standalone test .npy file (N, L, L) int8 {-1,+1}
    data      : path to the main spins.npy (used only if test_data is None)
    n         : number of grids to draw (capped at the number available)
    L         : lattice size
    rng       : generator used to pick a random subset without replacement

    Returns
    -------
    (min(n, available), L, L) array of ±1 spins.

    Raises
    ------
    ValueError if the test file does not hold (N, L, L) grids of ±1 spins.
    """
    if test_data is not None:
        # The test file already stores ±1 grids in lattice layout, so they are
        # sampled directly; no token / sequence-order conversion is needed.
        spins = np.load(test_data)  # (N, L, L) int8
        if spins.ndim != 3 or spins.shape[1:] != (L, L):
            raise ValueError(
                f"test-data grid shape {spins.shape[1:]} != ({L},{L})"
            )
        if not np.isin(spins, (-1, 1)).all():
            raise ValueError("test-data must contain only -1/+1 spins")
        n = min(n, len(spins))
        idx = rng.choice(len(spins), size=n, replace=False)
        return spins[idx].astype(np.int8)

    _, tokens = load_ising_data(data)  # val split of spins.npy
    n = min(n, len(tokens))
    idx = rng.choice(len(tokens), size=n, replace=False)
    return tokens_to_grids(tokens[idx], L)  # (n, L, L), values ±1


def parse_args():
    """Parse and validate command-line arguments.

    --checkpoint is only required when no --samples-file is supplied, and
    --num-samples must be at least 2 so covariances can be estimated.
    """
    p = argparse.ArgumentParser(
        description="Compare generated vs real Ising samples via physical observables."
    )
    p.add_argument("--checkpoint", type=Path, default=None,
                   help="Path to the .eqx checkpoint file "
                        "(required unless --samples-file is given).")
    p.add_argument("--data", type=Path,
                   default=Path(__file__).parent / "spins.npy",
                   help="Path to spins.npy (default: ./spins.npy). "
                        "Used only if --test-data is not provided.")
    p.add_argument("--test-data", type=Path,
                   default=Path(__file__).parent / "spins_test.npy",
                   help="Dedicated held-out test set (.npy, N×L×L int8 {-1,+1}). "
                        "Takes priority over the val split of --data.")
    p.add_argument("--num-samples", type=int, default=50,
                   help="Number of samples to compare (default: 50).")
    p.add_argument("--samples-file", type=Path, default=None,
                   help="Optional .npy of pre-generated {-1,+1} grids (N,L,L) "
                        "from 'sample.py --output'. Skips generation entirely.")
    p.add_argument("--seed", type=int, default=0)
    args = p.parse_args()

    if args.samples_file is None and args.checkpoint is None:
        p.error("--checkpoint is required unless --samples-file is given")
    if args.num_samples < 2:
        p.error("--num-samples must be at least 2 (covariance estimates need two samples)")
    return args


def main():
    """Load real and generated grids, then print feature, ensemble and distance reports."""
    args = parse_args()
    L = gen_config["lattice_size"]
    rng = np.random.default_rng(args.seed)

    # ── Real samples (test split) ─────────────────────────────────────────────
    # Prefer spins_test.npy; fall back to val split of spins.npy.
    test_path = args.test_data if (args.test_data and args.test_data.exists()) else None
    if test_path:
        print(f"Loading test data from {test_path} …")
    else:
        if args.test_data is not None:
            print(f"Test file {args.test_data} not found; using the val split instead.")
        print("Loading test data from val split of spins.npy …")

    n = args.num_samples
    real_grids = load_test_grids(test_path, args.data, n, L, rng)
    n = len(real_grids)  # may be capped by dataset size

    # ── Generated samples ─────────────────────────────────────────────────────
    if args.samples_file is not None:
        print(f"Loading pre-generated samples from {args.samples_file} …")
        raw = np.load(args.samples_file)[:n]
        if raw.ndim != 3 or raw.shape[1:] != (L, L):
            raise ValueError(
                f"samples-file grid shape {raw.shape[1:]} != ({L},{L})"
            )
        # Check values before casting: e.g. uint8 255 would wrap to int8 -1.
        if not np.isin(raw, (-1, 1)).all():
            raise ValueError("samples-file must contain only -1/+1 spins")
        gen_grids = raw.astype(np.int8)
        n = min(n, len(gen_grids))
        real_grids = real_grids[:n]
    else:
        print(f"Loading checkpoint from {args.checkpoint} …")
        model = load_checkpoint(args.checkpoint)
        key = jax.random.PRNGKey(args.seed)
        gen_grids = generate_grids(model, n, key, L)  # (n, L, L), values ±1

    if n < 2:
        raise ValueError(
            f"need at least 2 samples per group for covariance estimates, got {n}"
        )

    print(f"\nL = {L}  |  N = {n} samples per group  |  T_C = {T_C:.6f}\n")

    # ── Feature matrices ──────────────────────────────────────────────────────
    X_real = compute_feature_matrix(real_grids, desc="Features: real     ")
    X_gen = compute_feature_matrix(gen_grids, desc="Features: generated ")

    # ── Per-feature comparison table ──────────────────────────────────────────
    print("Per-feature statistics  (z-score = Δμ / σ_real;  '<' marks |z| > 1)\n")
    print_feature_table(X_real, X_gen)

    # ── Ensemble statistics ───────────────────────────────────────────────────
    print("Ensemble statistics\n")
    print_ensemble_table(ensemble_stats(X_real), ensemble_stats(X_gen))

    # ── Mahalanobis distance ──────────────────────────────────────────────────
    D, z = mahalanobis_distance(X_real, X_gen)
    print(f"Mahalanobis distance  D = {D:.4f}")
    print(  "  (D measures how many 'std-devs' the generated feature mean sits")
    print(  "   from the real distribution in the decorrelated feature space.)")
    print()
    print("  Top deviating features:")
    order = np.argsort(np.abs(z))[::-1]
    for i in order[:5]:
        print(f"    {FEATURE_NAMES[i]:<11}  z = {z[i]:+.3f}")


if __name__ == "__main__":
    main()