""" Signal reconstruction evaluation for Multi-Mechanism Normalizing Flow. For each test sample: 1. Infer mechanism + parameter posterior 2. Reconstruct signals from posterior mean, MAP, and random samples 3. Compare reconstructed signals with observed signals This validates whether the inferred posteriors produce physically consistent predictions, even when individual parameters have poor R² (due to compensation). Usage: python evaluate_reconstruction.py --checkpoint outputs/multi_mechanism_multiscan/.../best.pt python evaluate_reconstruction.py --checkpoint outputs/tpd_multiheat/.../best.pt --domain tpd """ import os import sys import json import glob import signal import argparse import time as time_module from pathlib import Path from collections import defaultdict import numpy as np import torch from tqdm import tqdm import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt class _Timeout: """Context manager that raises TimeoutError after `seconds` seconds. Uses POSIX signal.alarm in the main thread; falls back to a no-op in worker threads (e.g. Gradio request handlers) where signals are unavailable. 
""" def __init__(self, seconds): self.seconds = seconds self._use_signal = False def _handler(self, signum, frame): raise TimeoutError(f"Reconstruction timed out after {self.seconds}s") def __enter__(self): import threading if threading.current_thread() is threading.main_thread(): self._use_signal = True self._old = signal.signal(signal.SIGALRM, self._handler) signal.alarm(self.seconds) return self def __exit__(self, *args): if self._use_signal: signal.alarm(0) signal.signal(signal.SIGALRM, self._old) def parse_args(): parser = argparse.ArgumentParser(description="Evaluate signal reconstruction") parser.add_argument("--checkpoint", type=str, required=True) parser.add_argument("--domain", type=str, default="ec", choices=["ec", "tpd"]) parser.add_argument("--split", type=str, default="test", choices=["train", "val", "test"]) parser.add_argument("--data_dir", type=str, default=None, help="Override data directory (e.g. for clean test set)") parser.add_argument("--max_samples", type=int, default=200) parser.add_argument("--n_posterior_samples", type=int, default=100, help="Posterior samples for reconstruction") parser.add_argument("--n_recon_samples", type=int, default=10, help="Number of random posterior samples to reconstruct per test sample") parser.add_argument("--n_visualize", type=int, default=20) parser.add_argument("--batch_size", type=int, default=16) parser.add_argument("--temperature", type=float, default=1.0, help="Base distribution temperature for sampling (>1 broadens posteriors)") return parser.parse_args() # ============================================================================= # EC reconstruction # ============================================================================= def _safe_pow10(val): """Compute 10**val, clamping to avoid OverflowError for extreme values.""" val = np.clip(val, -300, 300) return 10.0 ** val def theta_to_ec_params(theta, mechanism, base_params): """ Convert model output (log-space) back to physical simulator parameters. 
Args: theta: [D] numpy array of inferred parameters mechanism: str mechanism name base_params: dict of fixed parameters from the original sample """ from flow_model import MECHANISM_PARAMS names = MECHANISM_PARAMS[mechanism]['names'] p = dict(base_params) p['kinetics'] = mechanism for i, name in enumerate(names): val = float(theta[i]) if name.startswith('log10(') and name.endswith(')'): phys_name = name[6:-1] p[phys_name] = _safe_pow10(val) elif name == 'alpha': p['alpha'] = val elif name == 'E0_offset': pass return p def reconstruct_ec_signal(theta, mechanism, base_params, sigmas, n_spatial=64): """ Reconstruct CV signal(s) from inferred parameters. Args: theta: [D] inferred parameters (model output space) mechanism: str base_params: dict with fixed params (theta_i, theta_v, dA, etc.) sigmas: list of scan rates n_spatial: spatial grid points Returns: list of dicts with 'potential', 'flux', 'time', 'c_ox_surface', 'c_red_surface' per scan rate """ import warnings from generate_dataset_diffec import _run_single_cv phys = theta_to_ec_params(theta, mechanism, base_params) K0_at_1 = phys.get('K0', 1.0) kc_at_1 = phys.get('kc', 1.0) results = [] for sigma in sigmas: p = dict(phys) p['sigma'] = float(sigma) if mechanism in ('BV', 'MHC', 'EC', 'LH'): p['K0'] = K0_at_1 / np.sqrt(sigma) elif mechanism == 'Ads': p['K0'] = K0_at_1 / sigma if mechanism == 'EC': p['kc'] = kc_at_1 / sigma try: with _Timeout(30), warnings.catch_warnings(): warnings.simplefilter("ignore", RuntimeWarning) result = _run_single_cv(p, n_spatial) entry = { 'potential': result['potential'], 'flux': result['flux'], 'time': result['time'], 'success': True, } if 'c_ox' in result and 'c_red' in result: entry['c_ox_surface'] = result['c_ox'][:, -1] entry['c_red_surface'] = result['c_red'][:, 0] results.append(entry) except Exception as e: results.append({'success': False, 'error': str(e)}) return results # ============================================================================= # TPD reconstruction # 
# =============================================================================

def theta_to_tpd_params(theta, mechanism, base_params):
    """Convert model output back to physical TPD simulator parameters.

    Args:
        theta: [D] inferred parameters (model output space)
        mechanism: str mechanism name
        base_params: dict of fixed parameters from the original sample

    Returns:
        A new dict: a copy of `base_params` with 'mechanism' set and every
        inferred parameter written over it. Entries named 'log10(X)' are
        stored under 'X' as 10**value; all other entries are stored verbatim.
    """
    from generate_tpd_data import TPD_MECHANISM_PARAMS

    names = TPD_MECHANISM_PARAMS[mechanism]['names']
    p = dict(base_params)
    p['mechanism'] = mechanism
    for i, name in enumerate(names):
        val = float(theta[i])
        if name.startswith('log10(') and name.endswith(')'):
            # Strip the "log10(" prefix and ")" suffix to get the physical name.
            phys_name = name[6:-1]
            p[phys_name] = _safe_pow10(val)
        else:
            p[name] = val
    return p


def reconstruct_tpd_signal(theta, mechanism, base_params, betas):
    """
    Reconstruct TPD signal(s) from inferred parameters.

    Args:
        theta: [D] inferred parameters
        mechanism: str
        base_params: dict with T_start, T_end, etc.
        betas: list of heating rates

    Returns:
        list with one dict per heating rate. A successful entry has
        'temperature', 'rate', 'time' and 'success': True; a failed entry
        (simulator error or 30 s timeout) is
        {'success': False, 'error': <message>}.
    """
    import warnings
    from generate_tpd_data import _run_single_tpd

    phys = theta_to_tpd_params(theta, mechanism, base_params)
    results = []
    for beta in betas:
        p = dict(phys)
        p['beta'] = float(beta)
        # Best effort: a failing simulation is recorded, not propagated.
        try:
            with _Timeout(30), warnings.catch_warnings():
                warnings.simplefilter("ignore", RuntimeWarning)
                result = _run_single_tpd(p)
            results.append({
                'temperature': result['temperature'],
                'rate': result['rate'],
                'time': result['time'],
                'success': True,
            })
        except Exception as e:
            results.append({'success': False, 'error': str(e)})
    return results


# =============================================================================
# Metrics
# =============================================================================

def _valid_signal(arr):
    """Check that a signal array is non-empty and has no NaN/Inf/extreme values."""
    arr = np.asarray(arr)
    if arr.size == 0:
        # np.max() raises on an empty array; an empty signal is never valid.
        return False
    if not np.all(np.isfinite(arr)):
        return False
    if np.max(np.abs(arr)) > 1e30:
        return False
    return True


def _aligned_signals(observed, reconstructed, length=None):
    """Truncate both signals to a common length; return None if either is invalid."""
    min_len = min(len(observed), len(reconstructed))
    if length is not None:
        min_len = min(min_len, length)
    o = np.asarray(observed)[:min_len]
    r = np.asarray(reconstructed)[:min_len]
    if not (_valid_signal(o) and _valid_signal(r)):
        return None
    return o, r


def signal_rmse(observed, reconstructed, length=None):
    """RMSE between two signals, optionally truncated to valid length.

    Returns NaN when the compared span is empty or contains NaN/Inf/extreme
    values in either signal.
    """
    pair = _aligned_signals(observed, reconstructed, length)
    if pair is None:
        return float('nan')
    o, r = pair
    return float(np.sqrt(np.mean((o - r) ** 2)))


def signal_nrmse(observed, reconstructed, length=None):
    """Normalized RMSE (by peak-to-peak range of observed signal).

    Returns NaN for an empty/invalid span and +inf when the observed signal
    is (numerically) constant.
    """
    pair = _aligned_signals(observed, reconstructed, length)
    if pair is None:
        return float('nan')
    o, r = pair
    ptp = np.ptp(o)
    if ptp < 1e-20:
        return float('inf')
    return float(np.sqrt(np.mean((o - r) ** 2)) / ptp)


def signal_r2(observed, reconstructed, length=None):
    """R² between observed and reconstructed signals.

    Returns NaN for an empty/invalid span and 0.0 when the observed signal
    has (numerically) zero variance.
    """
    pair = _aligned_signals(observed, reconstructed, length)
    if pair is None:
        return float('nan')
    o, r = pair
    ss_res = np.sum((o - r) ** 2)
    ss_tot = np.sum((o - np.mean(o)) ** 2)
    if ss_tot < 1e-20:
        return 0.0
    return float(1 - ss_res / ss_tot)


# =============================================================================
# Shared evaluation machinery (used by both the EC and the TPD entry points)
# =============================================================================

def _load_json(path):
    """Read and return the JSON document stored at `path`."""
    with open(path) as f:
        return json.load(f)


def _build_model(model_cls, checkpoint, ckpt_dir, mechanism_list, device):
    """Instantiate `model_cls` from checkpoint hyper-parameters and load its weights."""
    ckpt_args = checkpoint['args']
    model = model_cls(
        d_context=ckpt_args.get('d_context', 128),
        d_model=ckpt_args.get('d_model', 128),
        n_coupling_layers=ckpt_args.get('n_coupling_layers', 6),
        hidden_dim=ckpt_args.get('hidden_dim', 96),
        coupling_type=ckpt_args.get('coupling_type', 'spline'),
        n_bins=ckpt_args.get('n_bins', 8),
        tail_bound=ckpt_args.get('tail_bound', 5.0),
    )

    theta_stats = _load_json(ckpt_dir / "theta_stats.json")
    for mech in mechanism_list:
        if mech in theta_stats:
            model.set_theta_stats(
                mech,
                torch.tensor(theta_stats[mech]['mean']),
                torch.tensor(theta_stats[mech]['std']),
            )

    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    # NOTE(review): the guard tests for `_initialized` but reads and writes
    # `initialized`; presumably a property pair on the flow layers - confirm.
    for m in model.modules():
        if hasattr(m, '_initialized') and not m.initialized:
            m.initialized = True

    model = model.to(device)
    model.eval()
    return model


def _resolve_split_dir(data_dir, split, tmp_prefix):
    """Locate the sample directory for `split`, flattening per-mechanism layouts.

    Supports the assembled flat structure
        {data_dir}/{split}/sample_*.npz
    as well as the raw per-mechanism structure
        {data_dir}/{Mechanism}/{split}/sample_*.npz
    For the latter, a temporary flat directory of symlinks is created.

    Returns:
        (split_dir, tmp_dir); tmp_dir is None unless a temporary directory
        was created, in which case the caller must remove it.
    """
    import shutil
    import tempfile

    split_dir = os.path.join(data_dir, split)
    if os.path.exists(split_dir) and glob.glob(os.path.join(split_dir, "sample_*.npz")):
        return split_dir, None

    # Try raw per-mechanism structure
    mech_dirs = sorted(
        d for d in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, d, split))
    )
    if not mech_dirs:
        return split_dir, None

    print(f"Detected raw per-mechanism directory structure in {data_dir}")
    print(f" Mechanisms found: {mech_dirs}")

    tmp_dir = tempfile.mkdtemp(prefix=tmp_prefix)
    try:
        flat_dir = os.path.join(tmp_dir, split)
        os.makedirs(flat_dir, exist_ok=True)
        file_idx = 0
        for mech_name in mech_dirs:
            mech_split = os.path.join(data_dir, mech_name, split)
            for src in sorted(glob.glob(os.path.join(mech_split, "sample_*.npz"))):
                dst = os.path.join(flat_dir, f"sample_{file_idx:06d}.npz")
                os.symlink(os.path.abspath(src), dst)
                file_idx += 1
    except BaseException:
        # Do not leave a half-built temporary tree behind.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise

    print(f" Linked {file_idx} samples into temporary flat directory")
    return flat_dir, tmp_dir


def _kde_map_estimate(samples, theta_mean):
    """Per-dimension MAP estimate: the mode of a 1D KDE over posterior samples.

    Dimensions with (numerically) zero spread use the sample mean; if the KDE
    cannot be built or evaluated, the sample median is used instead.
    """
    from scipy.stats import gaussian_kde

    theta_map = np.zeros_like(theta_mean)
    for d in range(len(theta_mean)):
        s = samples[:, d]
        if np.std(s) < 1e-10:
            theta_map[d] = np.mean(s)
            continue
        try:
            kde = gaussian_kde(s)
            grid = np.linspace(s.min(), s.max(), 200)
            theta_map[d] = grid[np.argmax(kde(grid))]
        except Exception:
            # Deliberate best effort: fall back to a robust point estimate.
            theta_map[d] = np.median(s)
    return theta_map


def _finite_scores(metric, observed, lengths, recons, key):
    """Apply `metric` to every successful reconstruction; keep finite values only."""
    scores = []
    for s_idx, recon in enumerate(recons):
        if not recon['success']:
            continue
        value = metric(observed[s_idx], recon[key], lengths[s_idx])
        if np.isfinite(value):
            scores.append(value)
    return scores


def _plot_reconstruction(path, title, obs_x, obs_y, lengths, rates, recon_mean,
                         recon_map, sample_recons, x_key, y_key, xlabel, ylabel,
                         panel_title_fmt):
    """Save one figure comparing observed and reconstructed signals, one panel per rate."""
    n_panels = len(rates)
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4))
    try:
        if n_panels == 1:
            axes = [axes]
        for s_idx, ax in enumerate(axes):
            length = lengths[s_idx]
            panel_x = obs_x[s_idx, :length]
            panel_y = obs_y[s_idx, :length]
            ax.plot(panel_x, panel_y, 'k-', lw=1.5, label='Observed', alpha=0.8)

            for recons, tag, fmt in ((recon_mean, 'Mean', 'r--'),
                                     (recon_map, 'MAP', 'b:')):
                entry = recons[s_idx]
                if not entry['success']:
                    continue
                r_x = entry[x_key]
                r_y = entry[y_key]
                min_len = min(length, len(r_x))
                nrmse_val = signal_nrmse(panel_y, r_y, length)
                if np.isfinite(nrmse_val):
                    lbl = f'{tag} (NRMSE={nrmse_val:.3f})'
                else:
                    lbl = f'{tag} (NRMSE=N/A)'
                ax.plot(r_x[:min_len], r_y[:min_len], fmt, lw=1.2, label=lbl)

            # Plot a few posterior samples
            for recon_s in sample_recons:
                entry = recon_s[s_idx]
                if not entry['success']:
                    continue
                r_x = entry[x_key]
                r_y = entry[y_key]
                min_len = min(length, len(r_x))
                ax.plot(r_x[:min_len], r_y[:min_len], '-', lw=0.5, alpha=0.3,
                        color='gray')

            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            ax.set_title(panel_title_fmt.format(rates[s_idx]))
            ax.legend(fontsize=7)

        fig.suptitle(title, fontsize=12)
        plt.tight_layout()
        plt.savefig(path, dpi=150)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


def _summarize_metrics(mechanism_list, metrics, with_quartiles):
    """Print per-mechanism statistics and return them as a JSON-serialisable dict."""
    results = {}
    for mech in mechanism_list:
        if not metrics['nrmse_mean'][mech]:
            continue

        nrmse_mean = np.array(metrics['nrmse_mean'][mech])
        nrmse_map = np.array(metrics['nrmse_map'][mech])
        nrmse_samp = np.array(metrics['nrmse_samples'][mech])
        r2_mean = np.array(metrics['r2_mean'][mech])
        r2_map = np.array(metrics['r2_map'][mech])

        nrmse_mean_stats = {
            'median': float(np.median(nrmse_mean)),
            'mean': float(np.mean(nrmse_mean)),
            'std': float(np.std(nrmse_mean)),
        }
        if with_quartiles:
            nrmse_mean_stats['q25'] = float(np.percentile(nrmse_mean, 25))
            nrmse_mean_stats['q75'] = float(np.percentile(nrmse_mean, 75))

        results[mech] = {
            'n_samples': len(nrmse_mean),
            'nrmse_mean': nrmse_mean_stats,
            'nrmse_map': {'median': float(np.median(nrmse_map)),
                          'mean': float(np.mean(nrmse_map)),
                          'std': float(np.std(nrmse_map))},
            'r2_signal_mean': {'median': float(np.median(r2_mean)),
                               'mean': float(np.mean(r2_mean))},
            'r2_signal_map': {'median': float(np.median(r2_map)),
                              'mean': float(np.mean(r2_map))},
        }
        if len(nrmse_samp) > 0:
            results[mech]['nrmse_posterior_median'] = {
                'median': float(np.median(nrmse_samp)),
                'mean': float(np.mean(nrmse_samp)),
            }

        print(f"\n{mech} ({len(nrmse_mean)} samples):")
        print(f" Signal NRMSE (mean est): median={np.median(nrmse_mean):.4f} "
              f"mean={np.mean(nrmse_mean):.4f} ± {np.std(nrmse_mean):.4f}")
        print(f" Signal NRMSE (MAP est): median={np.median(nrmse_map):.4f} "
              f"mean={np.mean(nrmse_map):.4f} ± {np.std(nrmse_map):.4f}")
        if len(nrmse_samp) > 0:
            print(f" Signal NRMSE (post. med): median={np.median(nrmse_samp):.4f} "
                  f"mean={np.mean(nrmse_samp):.4f}")
        print(f" Signal R² (mean est): median={np.median(r2_mean):.4f}")
        print(f" Signal R² (MAP est): median={np.median(r2_map):.4f}")
    return results


def _evaluate_reconstruction(args, *, model_cls, dataset_cls, collate_fn,
                             mechanism_list, reconstruct_fn, default_data_dir,
                             tmp_prefix, norm_fields, x_key, y_key, rates_key,
                             rate_param, xlabel, ylabel, panel_title_fmt,
                             start_label, report_title, with_quartiles):
    """Domain-independent evaluation loop shared by `evaluate_ec` and `evaluate_tpd`.

    Args:
        args: parsed command-line namespace (see `parse_args`).
        model_cls: flow model class to instantiate from the checkpoint.
        dataset_cls: dataset class; built twice (normalised and raw inputs).
        collate_fn: collate function handed to the DataLoader.
        mechanism_list: ordered mechanism names, indexed by 'mechanism_id'.
        reconstruct_fn: callable(theta, mechanism, base_params, rates) that
            returns one result dict per rate.
        default_data_dir: data directory used when neither --data_dir nor the
            checkpoint arguments provide one.
        tmp_prefix: prefix of the temporary flat directory, if one is needed.
        norm_fields: names in norm_stats.json whose [mean, std] are copied to
            `<name>_mean` / `<name>_std` on the normalised dataset.
        x_key, y_key: keys of the abscissa / signal arrays, both in the raw
            .npz files and in the reconstruction dicts.
        rates_key: .npz key holding the per-trace rates of multi-trace samples.
        rate_param: key in the sample's params used for single-trace samples.
        xlabel, ylabel, panel_title_fmt: plot labelling.
        start_label, report_title: console messages.
        with_quartiles: also report q25/q75 of the mean-estimate NRMSE.
    """
    import shutil
    from torch.utils.data import DataLoader

    ckpt_path = os.path.expanduser(args.checkpoint)
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    ckpt_args = checkpoint['args']

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    ckpt_dir = Path(ckpt_path).parent.parent
    model = _build_model(model_cls, checkpoint, ckpt_dir, mechanism_list, device)
    norm_stats = _load_json(ckpt_dir / "norm_stats.json")

    if args.data_dir:
        data_dir = os.path.expanduser(args.data_dir)
    else:
        data_dir = os.path.expanduser(ckpt_args.get('data_dir', default_data_dir))

    split_dir, tmp_dir = _resolve_split_dir(data_dir, args.split, tmp_prefix)
    raw_per_mechanism = tmp_dir is not None
    vis_count = defaultdict(int)
    try:
        print(f"Loading data from: {split_dir}")
        dataset = dataset_cls(split_dir, max_samples=args.max_samples,
                              normalize_input=True)
        # Use the training-time normalisation, not statistics of this split.
        for field in norm_fields:
            setattr(dataset, f"{field}_mean", norm_stats[field][0])
            setattr(dataset, f"{field}_std", norm_stats[field][1])
        raw_dataset = dataset_cls(split_dir, max_samples=args.max_samples,
                                  normalize_input=False)
        loader = DataLoader(dataset, batch_size=1, shuffle=False,
                            collate_fn=collate_fn)

        if raw_per_mechanism:
            eval_dir = ckpt_dir / f"eval_recon_clean_{args.split}"
        else:
            eval_dir = ckpt_dir / f"eval_recon_{args.split}"
        eval_dir.mkdir(exist_ok=True)

        metrics = {
            name: defaultdict(list)
            for name in ('nrmse_mean', 'nrmse_map', 'nrmse_samples',
                         'r2_mean', 'r2_map')
        }

        print(f"Evaluating {start_label} on {len(dataset)} samples...")
        for idx, batch in enumerate(tqdm(loader, desc="Reconstructing")):
            x = batch['input'].to(device)
            scan_mask = batch['scan_mask'].to(device)
            sigmas_tensor = batch['sigmas'].to(device)
            flux_scales = batch['flux_scales'].to(device)
            mech_id = batch['mechanism_id'].item()
            if mech_id < 0 or mech_id >= len(mechanism_list):
                continue
            mech = mechanism_list[mech_id]

            # NOTE(security): allow_pickle=True is needed for the 'params'
            # dict but executes pickled code; only read trusted data files.
            with np.load(raw_dataset.sample_files[idx], allow_pickle=True) as raw_data:
                raw_params = raw_data['params'].item()
                obs_y = raw_data[y_key].astype(np.float32)
                obs_x = raw_data[x_key].astype(np.float32)
                if rates_key in raw_data:
                    rates = raw_data[rates_key].astype(np.float64)
                    lengths = raw_data['lengths'].astype(int)
                else:
                    # Single-trace sample: promote to one row so the code
                    # below can index [trace, time] uniformly.
                    rates = np.array([raw_params.get(rate_param, 1.0)])
                    lengths = np.array([len(obs_y)])
                    obs_y = obs_y[np.newaxis, :]
                    obs_x = obs_x[np.newaxis, :]

            with torch.no_grad():
                pred = model.predict(x, scan_mask=scan_mask, sigmas=sigmas_tensor,
                                     flux_scales=flux_scales,
                                     n_samples=args.n_posterior_samples,
                                     temperature=args.temperature)
            if pred['stats'][mech] is None:
                continue

            theta_mean = pred['stats'][mech]['mean'][0].cpu().numpy()
            samples = pred['samples'][mech][0].cpu().numpy()
            # MAP estimate via 1D KDE
            theta_map = _kde_map_estimate(samples, theta_mean)

            base_params = dict(raw_params)
            # Reconstruct from mean and from MAP
            recon_mean = reconstruct_fn(theta_mean, mech, base_params, rates)
            recon_map = reconstruct_fn(theta_map, mech, base_params, rates)

            # Per-sample score = mean over the traces that gave a finite value
            for name, recons, metric in (
                ('nrmse_mean', recon_mean, signal_nrmse),
                ('r2_mean', recon_mean, signal_r2),
                ('nrmse_map', recon_map, signal_nrmse),
                ('r2_map', recon_map, signal_r2),
            ):
                scores = _finite_scores(metric, obs_y, lengths, recons, y_key)
                if scores:
                    metrics[name][mech].append(np.mean(scores))

            # Reconstruct from random posterior samples
            n_recon = min(args.n_recon_samples, samples.shape[0])
            sample_indices = np.random.choice(samples.shape[0], n_recon,
                                              replace=False)
            sample_recons = [
                reconstruct_fn(samples[si], mech, base_params, rates)
                for si in sample_indices
            ]
            sample_nrmses = []
            for recon_s in sample_recons:
                scores = _finite_scores(signal_nrmse, obs_y, lengths, recon_s, y_key)
                if scores:
                    sample_nrmses.append(np.mean(scores))
            if sample_nrmses:
                metrics['nrmse_samples'][mech].append(np.median(sample_nrmses))

            # Visualization
            if vis_count[mech] < args.n_visualize and recon_mean[0]['success']:
                # The first three sample reconstructions computed above are
                # reused here rather than re-running the simulator per panel.
                _plot_reconstruction(
                    eval_dir / f"recon_{mech}_{vis_count[mech]:03d}.png",
                    f'{mech} sample {idx}',
                    obs_x, obs_y, lengths, rates,
                    recon_mean, recon_map, sample_recons[:3],
                    x_key, y_key, xlabel, ylabel, panel_title_fmt,
                )
                vis_count[mech] += 1

        # Print and save results
        print("\n" + "=" * 60)
        print(report_title)
        print("=" * 60)
        results = _summarize_metrics(mechanism_list, metrics, with_quartiles)
        with open(eval_dir / "reconstruction_results.json", "w") as f:
            json.dump(results, f, indent=2)
    finally:
        # Remove the temporary symlink tree even when the evaluation failed.
        if tmp_dir is not None:
            shutil.rmtree(tmp_dir, ignore_errors=True)

    print(f"\nResults saved to {eval_dir}")
    print(f"Visualizations: {sum(vis_count.values())} plots saved")


# =============================================================================
# EC evaluation
# =============================================================================

def evaluate_ec(args):
    """Evaluate signal reconstruction for the electrochemistry (CV) domain.

    Loads the checkpoint named by `args.checkpoint`, reconstructs the flux of
    every evaluated sample from the posterior mean, MAP and random posterior
    samples, and writes plots plus `reconstruction_results.json` to an
    `eval_recon_*` directory two levels above the checkpoint file.
    """
    from multi_mechanism_model import MultiMechanismFlow
    from flow_model import MECHANISM_LIST
    from dataset import DiffECDataset, collate_fn

    _evaluate_reconstruction(
        args,
        model_cls=MultiMechanismFlow,
        dataset_cls=DiffECDataset,
        collate_fn=collate_fn,
        mechanism_list=MECHANISM_LIST,
        reconstruct_fn=reconstruct_ec_signal,
        default_data_dir='~/DiffEC/data',
        tmp_prefix="ecflow_recon_",
        norm_fields=('potential', 'flux', 'time'),
        x_key='potential',
        y_key='flux',
        rates_key='sigmas',
        rate_param='sigma',
        xlabel='Potential (θ)',
        ylabel='Flux',
        panel_title_fmt='σ={:.2f}',
        start_label="signal reconstruction",
        report_title="SIGNAL RECONSTRUCTION METRICS",
        with_quartiles=True,
    )


# =============================================================================
# TPD evaluation
# =============================================================================

def evaluate_tpd(args):
    """Evaluate signal reconstruction for the TPD domain.

    Same procedure as `evaluate_ec`, applied to desorption-rate traces
    recorded at one or more heating rates.
    """
    from tpd_model import MultiMechanismFlowTPD
    from generate_tpd_data import TPD_MECHANISM_LIST
    from dataset_tpd import TPDDataset, collate_fn

    _evaluate_reconstruction(
        args,
        model_cls=MultiMechanismFlowTPD,
        dataset_cls=TPDDataset,
        collate_fn=collate_fn,
        mechanism_list=TPD_MECHANISM_LIST,
        reconstruct_fn=reconstruct_tpd_signal,
        default_data_dir='~/ECFlow/data_tpd_multiheat',
        tmp_prefix="ecflow_recon_tpd_",
        norm_fields=('temperature', 'rate'),
        x_key='temperature',
        y_key='rate',
        rates_key='heating_rates',
        rate_param='beta',
        xlabel='Temperature (K)',
        ylabel='Rate',
        panel_title_fmt='β={:.2f} K/s',
        start_label="TPD signal reconstruction",
        report_title="SIGNAL RECONSTRUCTION METRICS (TPD)",
        with_quartiles=False,
    )


if __name__ == "__main__":
    args = parse_args()
    if args.domain == "ec":
        evaluate_ec(args)
    else:
        evaluate_tpd(args)