| """ |
| Signal reconstruction evaluation for Multi-Mechanism Normalizing Flow. |
| |
| For each test sample: |
| 1. Infer mechanism + parameter posterior |
| 2. Reconstruct signals from posterior mean, MAP, and random samples |
| 3. Compare reconstructed signals with observed signals |
| |
| This validates whether the inferred posteriors produce physically consistent |
| predictions, even when individual parameters have poor R² (due to compensation). |
| |
| Usage: |
| python evaluate_reconstruction.py --checkpoint outputs/multi_mechanism_multiscan/.../best.pt |
| python evaluate_reconstruction.py --checkpoint outputs/tpd_multiheat/.../best.pt --domain tpd |
| """ |
|
|
| import os |
| import sys |
| import json |
| import glob |
| import signal |
| import argparse |
| import time as time_module |
| from pathlib import Path |
| from collections import defaultdict |
|
|
| import numpy as np |
| import torch |
| from tqdm import tqdm |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
|
|
|
|
| class _Timeout: |
| """Context manager that raises TimeoutError after `seconds` seconds. |
| |
| Uses POSIX signal.alarm in the main thread; falls back to a no-op in |
| worker threads (e.g. Gradio request handlers) where signals are unavailable. |
| """ |
| def __init__(self, seconds): |
| self.seconds = seconds |
| self._use_signal = False |
|
|
| def _handler(self, signum, frame): |
| raise TimeoutError(f"Reconstruction timed out after {self.seconds}s") |
|
|
| def __enter__(self): |
| import threading |
| if threading.current_thread() is threading.main_thread(): |
| self._use_signal = True |
| self._old = signal.signal(signal.SIGALRM, self._handler) |
| signal.alarm(self.seconds) |
| return self |
|
|
| def __exit__(self, *args): |
| if self._use_signal: |
| signal.alarm(0) |
| signal.signal(signal.SIGALRM, self._old) |
|
|
|
|
def parse_args():
    """Define and parse the command-line interface for this script.

    Returns:
        argparse.Namespace with checkpoint path, domain/split selection,
        sampling sizes, and the base-distribution temperature.
    """
    ap = argparse.ArgumentParser(description="Evaluate signal reconstruction")
    ap.add_argument("--checkpoint", type=str, required=True)
    ap.add_argument("--domain", type=str, default="ec", choices=["ec", "tpd"])
    ap.add_argument("--split", type=str, default="test",
                    choices=["train", "val", "test"])
    ap.add_argument("--data_dir", type=str, default=None,
                    help="Override data directory (e.g. for clean test set)")
    ap.add_argument("--max_samples", type=int, default=200)
    ap.add_argument("--n_posterior_samples", type=int, default=100,
                    help="Posterior samples for reconstruction")
    ap.add_argument("--n_recon_samples", type=int, default=10,
                    help="Number of random posterior samples to reconstruct per test sample")
    ap.add_argument("--n_visualize", type=int, default=20)
    ap.add_argument("--batch_size", type=int, default=16)
    ap.add_argument("--temperature", type=float, default=1.0,
                    help="Base distribution temperature for sampling (>1 broadens posteriors)")
    return ap.parse_args()
|
|
|
|
| |
| |
| |
|
|
| def _safe_pow10(val): |
| """Compute 10**val, clamping to avoid OverflowError for extreme values.""" |
| val = np.clip(val, -300, 300) |
| return 10.0 ** val |
|
|
|
|
def theta_to_ec_params(theta, mechanism, base_params):
    """Map inferred (log-space) model outputs onto physical EC simulator params.

    Args:
        theta: [D] numpy array of inferred parameters, ordered as in
            MECHANISM_PARAMS[mechanism]['names']
        mechanism: str mechanism name (also stored under 'kinetics')
        base_params: dict of fixed parameters copied from the original sample

    Returns:
        dict of physical parameters ready for the CV simulator.
    """
    from flow_model import MECHANISM_PARAMS

    params = dict(base_params)
    params['kinetics'] = mechanism

    for pos, pname in enumerate(MECHANISM_PARAMS[mechanism]['names']):
        value = float(theta[pos])
        if pname.startswith('log10(') and pname.endswith(')'):
            # Strip the 'log10(...)' wrapper and undo the log transform.
            params[pname[6:-1]] = _safe_pow10(value)
        elif pname == 'alpha':
            params['alpha'] = value
        elif pname == 'E0_offset':
            # Deliberately not mapped onto a physical simulator parameter.
            pass
        # Any other name falls through untouched (same as the original).

    return params
|
|
|
|
def reconstruct_ec_signal(theta, mechanism, base_params, sigmas, n_spatial=64):
    """
    Reconstruct CV signal(s) from inferred parameters.

    Runs the forward CV simulator once per scan rate, with per-scan-rate
    rescaling of the rate constants and a 30 s wall-clock guard per run.

    Args:
        theta: [D] inferred parameters (model output space)
        mechanism: str
        base_params: dict with fixed params (theta_i, theta_v, dA, etc.)
        sigmas: list of scan rates
        n_spatial: spatial grid points

    Returns:
        list of dicts with 'potential', 'flux', 'time',
        'c_ox_surface', 'c_red_surface' per scan rate; entries for failed
        runs contain only {'success': False, 'error': ...}
    """
    import warnings
    from generate_dataset_diffec import _run_single_cv

    phys = theta_to_ec_params(theta, mechanism, base_params)
    # Inferred K0/kc are treated as the values at sigma = 1 and rescaled per
    # scan rate below — presumably matching the simulator's dimensionless
    # convention; TODO confirm against generate_dataset_diffec.
    K0_at_1 = phys.get('K0', 1.0)
    kc_at_1 = phys.get('kc', 1.0)

    results = []
    for sigma in sigmas:
        p = dict(phys)
        p['sigma'] = float(sigma)
        # Mechanism-dependent scan-rate scaling: K0 / sqrt(sigma) for
        # BV/MHC/EC/LH, K0 / sigma for adsorption; EC additionally scales
        # its chemical rate constant kc by 1/sigma.
        if mechanism in ('BV', 'MHC', 'EC', 'LH'):
            p['K0'] = K0_at_1 / np.sqrt(sigma)
        elif mechanism == 'Ads':
            p['K0'] = K0_at_1 / sigma
        if mechanism == 'EC':
            p['kc'] = kc_at_1 / sigma

        try:
            # Hard 30 s limit per simulation; solver RuntimeWarnings are
            # silenced. Any failure is recorded per scan rate instead of
            # aborting the whole reconstruction.
            with _Timeout(30), warnings.catch_warnings():
                warnings.simplefilter("ignore", RuntimeWarning)
                result = _run_single_cv(p, n_spatial)
            entry = {
                'potential': result['potential'],
                'flux': result['flux'],
                'time': result['time'],
                'success': True,
            }
            if 'c_ox' in result and 'c_red' in result:
                # Surface concentrations: c_ox taken from the last spatial
                # column, c_red from the first. NOTE(review): the asymmetric
                # indexing may be intentional (grid orientation per species)
                # — confirm against _run_single_cv's grid layout.
                entry['c_ox_surface'] = result['c_ox'][:, -1]
                entry['c_red_surface'] = result['c_red'][:, 0]
            results.append(entry)
        except Exception as e:
            results.append({'success': False, 'error': str(e)})

    return results
|
|
|
|
| |
| |
| |
|
|
def theta_to_tpd_params(theta, mechanism, base_params):
    """Map inferred (log-space) model outputs onto physical TPD simulator params.

    Parameters named 'log10(x)' are exponentiated back to 'x'; every other
    parameter is copied through as a plain float. The mechanism label is
    stored under 'mechanism'.
    """
    from generate_tpd_data import TPD_MECHANISM_PARAMS

    params = dict(base_params)
    params['mechanism'] = mechanism

    for pos, pname in enumerate(TPD_MECHANISM_PARAMS[mechanism]['names']):
        value = float(theta[pos])
        if pname.startswith('log10(') and pname.endswith(')'):
            params[pname[6:-1]] = _safe_pow10(value)
        else:
            params[pname] = value

    return params
|
|
|
|
def reconstruct_tpd_signal(theta, mechanism, base_params, betas):
    """
    Reconstruct TPD signal(s) from inferred parameters.

    Runs the forward TPD simulator once per heating rate with a 30 s
    wall-clock guard; failures are recorded per heating rate instead of
    aborting the whole reconstruction.

    Args:
        theta: [D] inferred parameters
        mechanism: str
        base_params: dict with T_start, T_end, etc.
        betas: list of heating rates

    Returns:
        list of dicts with 'temperature', 'rate', 'time' per heating rate
    """
    import warnings
    from generate_tpd_data import _run_single_tpd

    phys = theta_to_tpd_params(theta, mechanism, base_params)

    outputs = []
    for beta in betas:
        run_params = dict(phys)
        run_params['beta'] = float(beta)

        try:
            with _Timeout(30), warnings.catch_warnings():
                warnings.simplefilter("ignore", RuntimeWarning)
                sim = _run_single_tpd(run_params)
            outputs.append({
                'temperature': sim['temperature'],
                'rate': sim['rate'],
                'time': sim['time'],
                'success': True,
            })
        except Exception as e:
            outputs.append({'success': False, 'error': str(e)})

    return outputs
|
|
|
|
| |
| |
| |
|
|
| def _valid_signal(arr): |
| """Check that a signal array contains no NaN/Inf/extreme values.""" |
| if not np.all(np.isfinite(arr)): |
| return False |
| if np.max(np.abs(arr)) > 1e30: |
| return False |
| return True |
|
|
|
|
def signal_rmse(observed, reconstructed, length=None):
    """Root-mean-square error between two signals.

    Both arrays are truncated to their common length (and optionally to
    `length`); returns NaN when either truncated signal is invalid
    (NaN/Inf/extreme values).
    """
    n = min(len(observed), len(reconstructed))
    if length is not None:
        n = min(n, length)
    obs, rec = observed[:n], reconstructed[:n]
    if not (_valid_signal(obs) and _valid_signal(rec)):
        return float('nan')
    diff = obs - rec
    return float(np.sqrt(np.mean(diff ** 2)))
|
|
|
|
def signal_nrmse(observed, reconstructed, length=None):
    """RMSE normalized by the observed signal's peak-to-peak range.

    Returns NaN for invalid (NaN/Inf/extreme) signals and +Inf when the
    observed range is effectively zero.
    """
    n = min(len(observed), len(reconstructed))
    if length is not None:
        n = min(n, length)
    obs, rec = observed[:n], reconstructed[:n]
    if not (_valid_signal(obs) and _valid_signal(rec)):
        return float('nan')
    span = np.ptp(obs)
    if span < 1e-20:
        return float('inf')
    rmse = np.sqrt(np.mean((obs - rec) ** 2))
    return float(rmse / span)
|
|
|
|
def signal_r2(observed, reconstructed, length=None):
    """Coefficient of determination (R²) between observed and reconstruction.

    Returns NaN for invalid signals; returns 0.0 when the observed signal
    has (near-)zero variance, to avoid dividing by zero.
    """
    n = min(len(observed), len(reconstructed))
    if length is not None:
        n = min(n, length)
    obs, rec = observed[:n], reconstructed[:n]
    if not (_valid_signal(obs) and _valid_signal(rec)):
        return float('nan')
    ss_tot = np.sum((obs - np.mean(obs)) ** 2)
    if ss_tot < 1e-20:
        return 0.0
    ss_res = np.sum((obs - rec) ** 2)
    return float(1 - ss_res / ss_tot)
|
|
|
|
| |
| |
| |
|
|
def evaluate_ec(args: argparse.Namespace) -> None:
    """Evaluate CV signal reconstruction for the electrochemistry (EC) domain.

    Per test sample: infer the parameter posterior with the flow model,
    re-simulate the CV signal(s) from the posterior mean, a per-dimension
    (marginal) KDE-based MAP, and random posterior samples, then score the
    reconstructions against the observed raw flux with NRMSE and R².

    Writes reconstruction_results.json and comparison plots into
    eval_recon_<split>/ (eval_recon_clean_<split>/ for raw per-mechanism
    data layouts) next to the checkpoint's run directory.
    """
    from multi_mechanism_model import MultiMechanismFlow
    # NOTE(review): MECHANISM_PARAMS is imported but unused in this function.
    from flow_model import MECHANISM_LIST, MECHANISM_PARAMS
    from dataset import DiffECDataset, collate_fn
    from torch.utils.data import DataLoader

    # --- Load checkpoint and rebuild the model with its training hyperparams.
    ckpt_path = os.path.expanduser(args.checkpoint)
    checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    ckpt_args = checkpoint['args']
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = MultiMechanismFlow(
        d_context=ckpt_args.get('d_context', 128),
        d_model=ckpt_args.get('d_model', 128),
        n_coupling_layers=ckpt_args.get('n_coupling_layers', 6),
        hidden_dim=ckpt_args.get('hidden_dim', 96),
        coupling_type=ckpt_args.get('coupling_type', 'spline'),
        n_bins=ckpt_args.get('n_bins', 8),
        tail_bound=ckpt_args.get('tail_bound', 5.0),
    )

    # Training-time per-mechanism theta normalization stats; the checkpoint
    # file sits two directory levels below the run directory.
    ckpt_dir = Path(ckpt_path).parent.parent
    theta_stats_path = ckpt_dir / "theta_stats.json"
    with open(theta_stats_path) as f:
        theta_stats = json.load(f)
    for mech in MECHANISM_LIST:
        if mech in theta_stats:
            model.set_theta_stats(
                mech,
                torch.tensor(theta_stats[mech]['mean']),
                torch.tensor(theta_stats[mech]['std']),
            )

    # Input normalization stats (mean, std) per channel saved at training time.
    norm_stats_path = ckpt_dir / "norm_stats.json"
    with open(norm_stats_path) as f:
        norm_stats = json.load(f)

    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    # Mark data-dependent-init layers as already initialized so evaluation
    # batches don't re-trigger initialization.
    # NOTE(review): the guard checks '_initialized' but reads/writes
    # 'initialized' — confirm the flow layers expose both attributes;
    # otherwise this loop is a no-op or raises AttributeError.
    for m in model.modules():
        if hasattr(m, '_initialized') and not m.initialized:
            m.initialized = True
    model = model.to(device)
    model.eval()

    # --- Resolve the dataset directory (CLI override wins over checkpoint).
    if args.data_dir:
        data_dir = os.path.expanduser(args.data_dir)
    else:
        data_dir = os.path.expanduser(ckpt_args.get('data_dir', '~/DiffEC/data'))

    # Two supported layouts: a flat <data_dir>/<split>/sample_*.npz directory,
    # or a raw per-mechanism layout <data_dir>/<mech>/<split>/sample_*.npz.
    # The latter is flattened into a temporary directory of symlinks so the
    # same Dataset class can consume it.
    split_dir = os.path.join(data_dir, args.split)
    raw_per_mechanism = False
    if not os.path.exists(split_dir) or not glob.glob(os.path.join(split_dir, "sample_*.npz")):
        mech_dirs = [d for d in os.listdir(data_dir)
                     if os.path.isdir(os.path.join(data_dir, d, args.split))]
        if mech_dirs:
            raw_per_mechanism = True
            print(f"Detected raw per-mechanism directory structure in {data_dir}")
            print(f" Mechanisms found: {sorted(mech_dirs)}")

            import tempfile
            tmp_dir = tempfile.mkdtemp(prefix="ecflow_recon_")
            flat_dir = os.path.join(tmp_dir, args.split)
            os.makedirs(flat_dir, exist_ok=True)
            file_idx = 0
            for mech_name in sorted(mech_dirs):
                mech_split = os.path.join(data_dir, mech_name, args.split)
                for f in sorted(glob.glob(os.path.join(mech_split, "sample_*.npz"))):
                    dst = os.path.join(flat_dir, f"sample_{file_idx:06d}.npz")
                    os.symlink(os.path.abspath(f), dst)
                    file_idx += 1
            split_dir = flat_dir
            print(f" Linked {file_idx} samples into temporary flat directory")

    print(f"Loading data from: {split_dir}")
    # Normalized view for the model; stats are overwritten with the
    # training-time values so inputs match what the model was trained on.
    dataset = DiffECDataset(split_dir, max_samples=args.max_samples, normalize_input=True)
    dataset.potential_mean = norm_stats['potential'][0]
    dataset.potential_std = norm_stats['potential'][1]
    dataset.flux_mean = norm_stats['flux'][0]
    dataset.flux_std = norm_stats['flux'][1]
    dataset.time_mean = norm_stats['time'][0]
    dataset.time_std = norm_stats['time'][1]

    # Unnormalized twin used only to read the raw .npz files back for scoring.
    raw_dataset = DiffECDataset(split_dir, max_samples=args.max_samples, normalize_input=False)

    # batch_size=1 keeps the loader index aligned with raw_dataset.sample_files.
    loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

    eval_dir = ckpt_dir / f"eval_recon_{args.split}"
    if raw_per_mechanism:
        eval_dir = ckpt_dir / f"eval_recon_clean_{args.split}"
    eval_dir.mkdir(exist_ok=True)

    # Per-mechanism metric accumulators (one entry per successfully scored sample).
    per_mech_nrmse_mean = defaultdict(list)
    per_mech_nrmse_map = defaultdict(list)
    per_mech_nrmse_samples = defaultdict(list)
    per_mech_r2_mean = defaultdict(list)
    per_mech_r2_map = defaultdict(list)
    # NOTE(review): n_failed is initialized but never incremented or reported.
    n_failed = 0
    vis_count = defaultdict(int)

    print(f"Evaluating signal reconstruction on {len(dataset)} samples...")

    for idx, batch in enumerate(tqdm(loader, desc="Reconstructing")):
        x = batch['input'].to(device)
        scan_mask = batch['scan_mask'].to(device)
        sigmas_tensor = batch['sigmas'].to(device)
        flux_scales = batch['flux_scales'].to(device)
        mech_id = batch['mechanism_id'].item()

        # Skip samples with unknown/invalid mechanism labels.
        if mech_id < 0 or mech_id >= len(MECHANISM_LIST):
            continue
        mech = MECHANISM_LIST[mech_id]

        # Raw (unnormalized) ground truth for scoring the reconstruction.
        raw_data = np.load(raw_dataset.sample_files[idx], allow_pickle=True)
        raw_params = raw_data['params'].item()
        raw_flux = raw_data['flux'].astype(np.float32)
        raw_potential = raw_data['potential'].astype(np.float32)

        # Multi-scan samples store per-scan arrays + lengths; single-scan
        # samples are promoted to a 1-row 2D layout for uniform handling.
        if 'sigmas' in raw_data:
            scan_rates = raw_data['sigmas'].astype(np.float64)
            lengths = raw_data['lengths'].astype(int)
        else:
            scan_rates = np.array([raw_params.get('sigma', 1.0)])
            lengths = np.array([len(raw_potential)])
            raw_flux = raw_flux[np.newaxis, :]
            raw_potential = raw_potential[np.newaxis, :]

        # Posterior inference: draws n_posterior_samples theta samples and
        # summary stats per mechanism.
        with torch.no_grad():
            pred = model.predict(x, scan_mask=scan_mask, sigmas=sigmas_tensor,
                                 flux_scales=flux_scales,
                                 n_samples=args.n_posterior_samples,
                                 temperature=args.temperature)

        if pred['stats'][mech] is None:
            continue

        theta_mean = pred['stats'][mech]['mean'][0].cpu().numpy()
        samples = pred['samples'][mech][0].cpu().numpy()

        # Per-dimension (marginal) MAP estimate: argmax of a 1-D Gaussian KDE
        # over each parameter's posterior samples; degenerate or failing
        # dimensions fall back to mean/median.
        from scipy.stats import gaussian_kde
        theta_map = np.zeros_like(theta_mean)
        for d in range(len(theta_mean)):
            s = samples[:, d]
            if np.std(s) < 1e-10:
                theta_map[d] = np.mean(s)
            else:
                try:
                    kde = gaussian_kde(s)
                    grid = np.linspace(s.min(), s.max(), 200)
                    theta_map[d] = grid[np.argmax(kde(grid))]
                except Exception:
                    theta_map[d] = np.median(s)

        base_params = dict(raw_params)

        # Forward-simulate from the two point estimates at every scan rate.
        recon_mean = reconstruct_ec_signal(theta_mean, mech, base_params, scan_rates)
        recon_map = reconstruct_ec_signal(theta_map, mech, base_params, scan_rates)

        nrmse_mean_list = []
        nrmse_map_list = []
        r2_mean_list = []
        r2_map_list = []

        # Score each scan rate independently; only finite metrics are kept.
        for s_idx in range(len(scan_rates)):
            obs_flux = raw_flux[s_idx]
            length = lengths[s_idx]

            if recon_mean[s_idx]['success']:
                v = signal_nrmse(obs_flux, recon_mean[s_idx]['flux'], length)
                r = signal_r2(obs_flux, recon_mean[s_idx]['flux'], length)
                if np.isfinite(v):
                    nrmse_mean_list.append(v)
                if np.isfinite(r):
                    r2_mean_list.append(r)
            if recon_map[s_idx]['success']:
                v = signal_nrmse(obs_flux, recon_map[s_idx]['flux'], length)
                r = signal_r2(obs_flux, recon_map[s_idx]['flux'], length)
                if np.isfinite(v):
                    nrmse_map_list.append(v)
                if np.isfinite(r):
                    r2_map_list.append(r)

        # Per-sample score = mean over scan rates.
        if nrmse_mean_list:
            per_mech_nrmse_mean[mech].append(np.mean(nrmse_mean_list))
        if r2_mean_list:
            per_mech_r2_mean[mech].append(np.mean(r2_mean_list))
        if nrmse_map_list:
            per_mech_nrmse_map[mech].append(np.mean(nrmse_map_list))
        if r2_map_list:
            per_mech_r2_map[mech].append(np.mean(r2_map_list))

        # Reconstruct from a few random posterior samples; keep the median
        # NRMSE across samples as this sample's "posterior sample" score.
        sample_nrmses = []
        n_recon = min(args.n_recon_samples, samples.shape[0])
        sample_indices = np.random.choice(samples.shape[0], n_recon, replace=False)
        for si in sample_indices:
            recon_s = reconstruct_ec_signal(samples[si], mech, base_params, scan_rates)
            nrmses = []
            for s_idx in range(len(scan_rates)):
                if recon_s[s_idx]['success']:
                    v = signal_nrmse(raw_flux[s_idx], recon_s[s_idx]['flux'], lengths[s_idx])
                    if np.isfinite(v):
                        nrmses.append(v)
            if nrmses:
                sample_nrmses.append(np.mean(nrmses))
        if sample_nrmses:
            per_mech_nrmse_samples[mech].append(np.median(sample_nrmses))

        # Plot the first n_visualize samples per mechanism: observed vs
        # mean/MAP reconstructions plus up to 3 posterior-sample traces.
        if vis_count[mech] < args.n_visualize and recon_mean[0]['success']:
            fig, axes = plt.subplots(1, len(scan_rates), figsize=(5 * len(scan_rates), 4))
            if len(scan_rates) == 1:
                axes = [axes]
            for s_idx, ax in enumerate(axes):
                length = lengths[s_idx]
                obs_pot = raw_potential[s_idx, :length]
                obs_flux_s = raw_flux[s_idx, :length]
                ax.plot(obs_pot, obs_flux_s, 'k-', lw=1.5, label='Observed', alpha=0.8)

                if recon_mean[s_idx]['success']:
                    r_pot = recon_mean[s_idx]['potential']
                    r_flux = recon_mean[s_idx]['flux']
                    min_len = min(length, len(r_pot))
                    nrmse_val = signal_nrmse(obs_flux_s, r_flux[:length] if length <= len(r_flux) else r_flux, length)
                    lbl = f'Mean (NRMSE={nrmse_val:.3f})' if np.isfinite(nrmse_val) else 'Mean (NRMSE=N/A)'
                    ax.plot(r_pot[:min_len], r_flux[:min_len], 'r--', lw=1.2, label=lbl)

                if recon_map[s_idx]['success']:
                    r_pot = recon_map[s_idx]['potential']
                    r_flux = recon_map[s_idx]['flux']
                    min_len = min(length, len(r_pot))
                    nrmse_val = signal_nrmse(obs_flux_s, r_flux[:length] if length <= len(r_flux) else r_flux, length)
                    lbl = f'MAP (NRMSE={nrmse_val:.3f})' if np.isfinite(nrmse_val) else 'MAP (NRMSE=N/A)'
                    ax.plot(r_pot[:min_len], r_flux[:min_len], 'b:', lw=1.2, label=lbl)

                # Re-simulate the first 3 chosen posterior samples at just
                # this scan rate (faint gray overlays).
                for si in sample_indices[:3]:
                    recon_s = reconstruct_ec_signal(samples[si], mech, base_params, [scan_rates[s_idx]])
                    if recon_s[0]['success']:
                        r_pot = recon_s[0]['potential']
                        r_flux = recon_s[0]['flux']
                        min_len = min(length, len(r_pot))
                        ax.plot(r_pot[:min_len], r_flux[:min_len], '-', lw=0.5,
                                alpha=0.3, color='gray')

                ax.set_xlabel('Potential (θ)')
                ax.set_ylabel('Flux')
                ax.set_title(f'σ={scan_rates[s_idx]:.2f}')
                ax.legend(fontsize=7)

            fig.suptitle(f'{mech} sample {idx}', fontsize=12)
            plt.tight_layout()
            plt.savefig(eval_dir / f"recon_{mech}_{vis_count[mech]:03d}.png", dpi=150)
            plt.close(fig)
            vis_count[mech] += 1

    # --- Aggregate per-mechanism metrics, print a summary, and dump JSON.
    print("\n" + "=" * 60)
    print("SIGNAL RECONSTRUCTION METRICS")
    print("=" * 60)

    results = {}
    for mech in MECHANISM_LIST:
        if not per_mech_nrmse_mean[mech]:
            continue

        nrmse_mean = np.array(per_mech_nrmse_mean[mech])
        nrmse_map = np.array(per_mech_nrmse_map[mech])
        nrmse_samp = np.array(per_mech_nrmse_samples[mech]) if per_mech_nrmse_samples[mech] else np.array([])
        r2_mean = np.array(per_mech_r2_mean[mech])
        r2_map = np.array(per_mech_r2_map[mech])

        results[mech] = {
            'n_samples': len(nrmse_mean),
            'nrmse_mean': {'median': float(np.median(nrmse_mean)),
                           'mean': float(np.mean(nrmse_mean)),
                           'std': float(np.std(nrmse_mean)),
                           'q25': float(np.percentile(nrmse_mean, 25)),
                           'q75': float(np.percentile(nrmse_mean, 75))},
            'nrmse_map': {'median': float(np.median(nrmse_map)),
                          'mean': float(np.mean(nrmse_map)),
                          'std': float(np.std(nrmse_map))},
            'r2_signal_mean': {'median': float(np.median(r2_mean)),
                               'mean': float(np.mean(r2_mean))},
            'r2_signal_map': {'median': float(np.median(r2_map)),
                              'mean': float(np.mean(r2_map))},
        }
        if len(nrmse_samp) > 0:
            results[mech]['nrmse_posterior_median'] = {
                'median': float(np.median(nrmse_samp)),
                'mean': float(np.mean(nrmse_samp)),
            }

        print(f"\n{mech} ({len(nrmse_mean)} samples):")
        print(f" Signal NRMSE (mean est): median={np.median(nrmse_mean):.4f} "
              f"mean={np.mean(nrmse_mean):.4f} ± {np.std(nrmse_mean):.4f}")
        print(f" Signal NRMSE (MAP est): median={np.median(nrmse_map):.4f} "
              f"mean={np.mean(nrmse_map):.4f} ± {np.std(nrmse_map):.4f}")
        if len(nrmse_samp) > 0:
            print(f" Signal NRMSE (post. med): median={np.median(nrmse_samp):.4f} "
                  f"mean={np.mean(nrmse_samp):.4f}")
        print(f" Signal R² (mean est): median={np.median(r2_mean):.4f}")
        print(f" Signal R² (MAP est): median={np.median(r2_map):.4f}")

    with open(eval_dir / "reconstruction_results.json", "w") as f:
        json.dump(results, f, indent=2)

    # Remove the temporary symlink directory created for raw layouts.
    if raw_per_mechanism:
        import shutil
        shutil.rmtree(tmp_dir, ignore_errors=True)

    print(f"\nResults saved to {eval_dir}")
    print(f"Visualizations: {sum(vis_count.values())} plots saved")
|
|
|
|
| |
| |
| |
|
|
def evaluate_tpd(args: argparse.Namespace) -> None:
    """Evaluate signal reconstruction for the TPD domain.

    Mirrors the EC evaluation: per test sample, infer the parameter
    posterior, re-simulate desorption-rate curves from the posterior mean,
    a per-dimension (marginal) KDE-based MAP, and random posterior samples,
    then score against the observed raw rate signal with NRMSE and R².

    Writes reconstruction_results.json and comparison plots into
    eval_recon_<split>/ (eval_recon_clean_<split>/ for raw per-mechanism
    data layouts) next to the checkpoint's run directory.
    """
    from tpd_model import MultiMechanismFlowTPD
    # NOTE(review): TPD_MECHANISM_PARAMS is imported but unused here.
    from generate_tpd_data import TPD_MECHANISM_LIST, TPD_MECHANISM_PARAMS
    from dataset_tpd import TPDDataset, collate_fn
    from torch.utils.data import DataLoader

    # --- Load checkpoint and rebuild the model with its training hyperparams.
    ckpt_path = os.path.expanduser(args.checkpoint)
    checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    ckpt_args = checkpoint['args']
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = MultiMechanismFlowTPD(
        d_context=ckpt_args.get('d_context', 128),
        d_model=ckpt_args.get('d_model', 128),
        n_coupling_layers=ckpt_args.get('n_coupling_layers', 6),
        hidden_dim=ckpt_args.get('hidden_dim', 96),
        coupling_type=ckpt_args.get('coupling_type', 'spline'),
        n_bins=ckpt_args.get('n_bins', 8),
        tail_bound=ckpt_args.get('tail_bound', 5.0),
    )

    # Training-time per-mechanism theta normalization stats; the checkpoint
    # file sits two directory levels below the run directory.
    ckpt_dir = Path(ckpt_path).parent.parent
    theta_stats_path = ckpt_dir / "theta_stats.json"
    with open(theta_stats_path) as f:
        theta_stats = json.load(f)
    for mech in TPD_MECHANISM_LIST:
        if mech in theta_stats:
            model.set_theta_stats(
                mech,
                torch.tensor(theta_stats[mech]['mean']),
                torch.tensor(theta_stats[mech]['std']),
            )

    # Input normalization stats (mean, std) per channel saved at training time.
    norm_stats_path = ckpt_dir / "norm_stats.json"
    with open(norm_stats_path) as f:
        norm_stats = json.load(f)

    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    # Mark data-dependent-init layers as already initialized so evaluation
    # batches don't re-trigger initialization.
    # NOTE(review): the guard checks '_initialized' but reads/writes
    # 'initialized' — confirm the flow layers expose both attributes;
    # otherwise this loop is a no-op or raises AttributeError.
    for m in model.modules():
        if hasattr(m, '_initialized') and not m.initialized:
            m.initialized = True
    model = model.to(device)
    model.eval()

    # --- Resolve the dataset directory (CLI override wins over checkpoint).
    if args.data_dir:
        data_dir = os.path.expanduser(args.data_dir)
    else:
        data_dir = os.path.expanduser(ckpt_args.get('data_dir', '~/ECFlow/data_tpd_multiheat'))

    # Two supported layouts: flat <data_dir>/<split>/sample_*.npz, or raw
    # per-mechanism <data_dir>/<mech>/<split>/sample_*.npz, which is
    # flattened into a temporary symlink directory for the Dataset class.
    split_dir = os.path.join(data_dir, args.split)
    raw_per_mechanism = False
    if not os.path.exists(split_dir) or not glob.glob(os.path.join(split_dir, "sample_*.npz")):
        mech_dirs = [d for d in os.listdir(data_dir)
                     if os.path.isdir(os.path.join(data_dir, d, args.split))]
        if mech_dirs:
            raw_per_mechanism = True
            print(f"Detected raw per-mechanism directory structure in {data_dir}")
            print(f" Mechanisms found: {sorted(mech_dirs)}")

            import tempfile
            tmp_dir = tempfile.mkdtemp(prefix="ecflow_recon_tpd_")
            flat_dir = os.path.join(tmp_dir, args.split)
            os.makedirs(flat_dir, exist_ok=True)
            file_idx = 0
            for mech_name in sorted(mech_dirs):
                mech_split = os.path.join(data_dir, mech_name, args.split)
                for f in sorted(glob.glob(os.path.join(mech_split, "sample_*.npz"))):
                    dst = os.path.join(flat_dir, f"sample_{file_idx:06d}.npz")
                    os.symlink(os.path.abspath(f), dst)
                    file_idx += 1
            split_dir = flat_dir
            print(f" Linked {file_idx} samples into temporary flat directory")

    print(f"Loading data from: {split_dir}")
    # Normalized view for the model; stats are overwritten with the
    # training-time values so inputs match what the model was trained on.
    dataset = TPDDataset(split_dir, max_samples=args.max_samples, normalize_input=True)
    dataset.temperature_mean = norm_stats['temperature'][0]
    dataset.temperature_std = norm_stats['temperature'][1]
    dataset.rate_mean = norm_stats['rate'][0]
    dataset.rate_std = norm_stats['rate'][1]

    # Unnormalized twin used only to read the raw .npz files back for scoring.
    raw_dataset = TPDDataset(split_dir, max_samples=args.max_samples, normalize_input=False)

    # batch_size=1 keeps the loader index aligned with raw_dataset.sample_files.
    loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

    eval_dir = ckpt_dir / f"eval_recon_{args.split}"
    if raw_per_mechanism:
        eval_dir = ckpt_dir / f"eval_recon_clean_{args.split}"
    eval_dir.mkdir(exist_ok=True)

    # Per-mechanism metric accumulators (one entry per successfully scored sample).
    per_mech_nrmse_mean = defaultdict(list)
    per_mech_nrmse_map = defaultdict(list)
    per_mech_nrmse_samples = defaultdict(list)
    per_mech_r2_mean = defaultdict(list)
    per_mech_r2_map = defaultdict(list)
    vis_count = defaultdict(int)

    print(f"Evaluating TPD signal reconstruction on {len(dataset)} samples...")

    for idx, batch in enumerate(tqdm(loader, desc="Reconstructing")):
        x = batch['input'].to(device)
        scan_mask = batch['scan_mask'].to(device)
        sigmas_tensor = batch['sigmas'].to(device)
        flux_scales = batch['flux_scales'].to(device)
        mech_id = batch['mechanism_id'].item()

        # Skip samples with unknown/invalid mechanism labels.
        if mech_id < 0 or mech_id >= len(TPD_MECHANISM_LIST):
            continue
        mech = TPD_MECHANISM_LIST[mech_id]

        # Raw (unnormalized) ground truth for scoring the reconstruction.
        raw_data = np.load(raw_dataset.sample_files[idx], allow_pickle=True)
        raw_params = raw_data['params'].item()
        raw_rate = raw_data['rate'].astype(np.float32)
        raw_temp = raw_data['temperature'].astype(np.float32)

        # Multi-heating-rate samples store per-run arrays + lengths; single
        # runs are promoted to a 1-row 2D layout for uniform handling.
        if 'heating_rates' in raw_data:
            betas = raw_data['heating_rates'].astype(np.float64)
            lengths = raw_data['lengths'].astype(int)
        else:
            betas = np.array([raw_params.get('beta', 1.0)])
            lengths = np.array([len(raw_rate)])
            raw_rate = raw_rate[np.newaxis, :]
            raw_temp = raw_temp[np.newaxis, :]

        # Posterior inference: draws n_posterior_samples theta samples and
        # summary stats per mechanism.
        with torch.no_grad():
            pred = model.predict(x, scan_mask=scan_mask, sigmas=sigmas_tensor,
                                 flux_scales=flux_scales,
                                 n_samples=args.n_posterior_samples,
                                 temperature=args.temperature)

        if pred['stats'][mech] is None:
            continue

        theta_mean = pred['stats'][mech]['mean'][0].cpu().numpy()
        samples = pred['samples'][mech][0].cpu().numpy()

        # Per-dimension (marginal) MAP estimate: argmax of a 1-D Gaussian KDE
        # over each parameter's posterior samples; degenerate or failing
        # dimensions fall back to mean/median.
        from scipy.stats import gaussian_kde
        theta_map = np.zeros_like(theta_mean)
        for d in range(len(theta_mean)):
            s = samples[:, d]
            if np.std(s) < 1e-10:
                theta_map[d] = np.mean(s)
            else:
                try:
                    kde = gaussian_kde(s)
                    grid = np.linspace(s.min(), s.max(), 200)
                    theta_map[d] = grid[np.argmax(kde(grid))]
                except Exception:
                    theta_map[d] = np.median(s)

        base_params = dict(raw_params)

        # Forward-simulate from the two point estimates at every heating rate.
        recon_mean = reconstruct_tpd_signal(theta_mean, mech, base_params, betas)
        recon_map = reconstruct_tpd_signal(theta_map, mech, base_params, betas)

        nrmse_mean_list = []
        nrmse_map_list = []
        r2_mean_list = []
        r2_map_list = []

        # Score each heating rate independently; only finite metrics are kept.
        for s_idx in range(len(betas)):
            obs_rate = raw_rate[s_idx]
            length = lengths[s_idx]

            if recon_mean[s_idx]['success']:
                v = signal_nrmse(obs_rate, recon_mean[s_idx]['rate'], length)
                r = signal_r2(obs_rate, recon_mean[s_idx]['rate'], length)
                if np.isfinite(v):
                    nrmse_mean_list.append(v)
                if np.isfinite(r):
                    r2_mean_list.append(r)
            if recon_map[s_idx]['success']:
                v = signal_nrmse(obs_rate, recon_map[s_idx]['rate'], length)
                r = signal_r2(obs_rate, recon_map[s_idx]['rate'], length)
                if np.isfinite(v):
                    nrmse_map_list.append(v)
                if np.isfinite(r):
                    r2_map_list.append(r)

        # Per-sample score = mean over heating rates.
        if nrmse_mean_list:
            per_mech_nrmse_mean[mech].append(np.mean(nrmse_mean_list))
        if r2_mean_list:
            per_mech_r2_mean[mech].append(np.mean(r2_mean_list))
        if nrmse_map_list:
            per_mech_nrmse_map[mech].append(np.mean(nrmse_map_list))
        if r2_map_list:
            per_mech_r2_map[mech].append(np.mean(r2_map_list))

        # Reconstruct from a few random posterior samples; keep the median
        # NRMSE across samples as this sample's "posterior sample" score.
        sample_nrmses = []
        n_recon = min(args.n_recon_samples, samples.shape[0])
        sample_indices = np.random.choice(samples.shape[0], n_recon, replace=False)
        for si in sample_indices:
            recon_s = reconstruct_tpd_signal(samples[si], mech, base_params, betas)
            nrmses = []
            for s_idx in range(len(betas)):
                if recon_s[s_idx]['success']:
                    v = signal_nrmse(raw_rate[s_idx], recon_s[s_idx]['rate'], lengths[s_idx])
                    if np.isfinite(v):
                        nrmses.append(v)
            if nrmses:
                sample_nrmses.append(np.mean(nrmses))
        if sample_nrmses:
            per_mech_nrmse_samples[mech].append(np.median(sample_nrmses))

        # Plot the first n_visualize samples per mechanism: observed vs
        # mean/MAP reconstructions plus up to 3 posterior-sample traces.
        if vis_count[mech] < args.n_visualize and recon_mean[0]['success']:
            fig, axes = plt.subplots(1, len(betas), figsize=(5 * len(betas), 4))
            if len(betas) == 1:
                axes = [axes]
            for s_idx, ax in enumerate(axes):
                length = lengths[s_idx]
                obs_temp = raw_temp[s_idx, :length]
                obs_rate_s = raw_rate[s_idx, :length]
                ax.plot(obs_temp, obs_rate_s, 'k-', lw=1.5, label='Observed', alpha=0.8)

                if recon_mean[s_idx]['success']:
                    r_temp = recon_mean[s_idx]['temperature']
                    r_rate = recon_mean[s_idx]['rate']
                    min_len = min(length, len(r_temp))
                    nrmse_val = signal_nrmse(obs_rate_s, r_rate[:length] if length <= len(r_rate) else r_rate, length)
                    lbl = f'Mean (NRMSE={nrmse_val:.3f})' if np.isfinite(nrmse_val) else 'Mean (NRMSE=N/A)'
                    ax.plot(r_temp[:min_len], r_rate[:min_len], 'r--', lw=1.2, label=lbl)

                if recon_map[s_idx]['success']:
                    r_temp = recon_map[s_idx]['temperature']
                    r_rate = recon_map[s_idx]['rate']
                    min_len = min(length, len(r_temp))
                    nrmse_val = signal_nrmse(obs_rate_s, r_rate[:length] if length <= len(r_rate) else r_rate, length)
                    lbl = f'MAP (NRMSE={nrmse_val:.3f})' if np.isfinite(nrmse_val) else 'MAP (NRMSE=N/A)'
                    ax.plot(r_temp[:min_len], r_rate[:min_len], 'b:', lw=1.2, label=lbl)

                # Re-simulate the first 3 chosen posterior samples at just
                # this heating rate (faint gray overlays).
                for si in sample_indices[:3]:
                    recon_s = reconstruct_tpd_signal(samples[si], mech, base_params, [betas[s_idx]])
                    if recon_s[0]['success']:
                        r_temp = recon_s[0]['temperature']
                        r_rate = recon_s[0]['rate']
                        min_len = min(length, len(r_temp))
                        ax.plot(r_temp[:min_len], r_rate[:min_len], '-', lw=0.5,
                                alpha=0.3, color='gray')

                ax.set_xlabel('Temperature (K)')
                ax.set_ylabel('Rate')
                ax.set_title(f'β={betas[s_idx]:.2f} K/s')
                ax.legend(fontsize=7)

            fig.suptitle(f'{mech} sample {idx}', fontsize=12)
            plt.tight_layout()
            plt.savefig(eval_dir / f"recon_{mech}_{vis_count[mech]:03d}.png", dpi=150)
            plt.close(fig)
            vis_count[mech] += 1

    # --- Aggregate per-mechanism metrics, print a summary, and dump JSON.
    print("\n" + "=" * 60)
    print("SIGNAL RECONSTRUCTION METRICS (TPD)")
    print("=" * 60)

    results = {}
    for mech in TPD_MECHANISM_LIST:
        if not per_mech_nrmse_mean[mech]:
            continue

        nrmse_mean = np.array(per_mech_nrmse_mean[mech])
        nrmse_map = np.array(per_mech_nrmse_map[mech])
        nrmse_samp = np.array(per_mech_nrmse_samples[mech]) if per_mech_nrmse_samples[mech] else np.array([])
        r2_mean = np.array(per_mech_r2_mean[mech])
        r2_map = np.array(per_mech_r2_map[mech])

        results[mech] = {
            'n_samples': len(nrmse_mean),
            'nrmse_mean': {'median': float(np.median(nrmse_mean)),
                           'mean': float(np.mean(nrmse_mean)),
                           'std': float(np.std(nrmse_mean))},
            'nrmse_map': {'median': float(np.median(nrmse_map)),
                          'mean': float(np.mean(nrmse_map)),
                          'std': float(np.std(nrmse_map))},
            'r2_signal_mean': {'median': float(np.median(r2_mean)),
                               'mean': float(np.mean(r2_mean))},
            'r2_signal_map': {'median': float(np.median(r2_map)),
                              'mean': float(np.mean(r2_map))},
        }
        if len(nrmse_samp) > 0:
            results[mech]['nrmse_posterior_median'] = {
                'median': float(np.median(nrmse_samp)),
                'mean': float(np.mean(nrmse_samp)),
            }

        print(f"\n{mech} ({len(nrmse_mean)} samples):")
        print(f" Signal NRMSE (mean est): median={np.median(nrmse_mean):.4f} "
              f"mean={np.mean(nrmse_mean):.4f} ± {np.std(nrmse_mean):.4f}")
        print(f" Signal NRMSE (MAP est): median={np.median(nrmse_map):.4f} "
              f"mean={np.mean(nrmse_map):.4f} ± {np.std(nrmse_map):.4f}")
        if len(nrmse_samp) > 0:
            print(f" Signal NRMSE (post. med): median={np.median(nrmse_samp):.4f} "
                  f"mean={np.mean(nrmse_samp):.4f}")
        print(f" Signal R² (mean est): median={np.median(r2_mean):.4f}")
        print(f" Signal R² (MAP est): median={np.median(r2_map):.4f}")

    with open(eval_dir / "reconstruction_results.json", "w") as f:
        json.dump(results, f, indent=2)

    # Remove the temporary symlink directory created for raw layouts.
    if raw_per_mechanism:
        import shutil
        shutil.rmtree(tmp_dir, ignore_errors=True)

    print(f"\nResults saved to {eval_dir}")
    print(f"Visualizations: {sum(vis_count.values())} plots saved")
|
|
|
|
if __name__ == "__main__":
    args = parse_args()
    # Dispatch on --domain; argparse restricts it to "ec" or "tpd", so the
    # else branch can only be the TPD evaluation.
    if args.domain == "ec":
        evaluate_ec(args)
    else:
        evaluate_tpd(args)
|
|