# ecflow / evaluate_reconstruction.py
# Author: Bing Yan
# Initial ECFlow deployment (commit d6b782e)
"""
Signal reconstruction evaluation for Multi-Mechanism Normalizing Flow.
For each test sample:
1. Infer mechanism + parameter posterior
2. Reconstruct signals from posterior mean, MAP, and random samples
3. Compare reconstructed signals with observed signals
This validates whether the inferred posteriors produce physically consistent
predictions, even when individual parameters have poor R² (due to compensation).
Usage:
python evaluate_reconstruction.py --checkpoint outputs/multi_mechanism_multiscan/.../best.pt
python evaluate_reconstruction.py --checkpoint outputs/tpd_multiheat/.../best.pt --domain tpd
"""
import os
import sys
import json
import glob
import signal
import argparse
import time as time_module
from pathlib import Path
from collections import defaultdict
import numpy as np
import torch
from tqdm import tqdm
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
class _Timeout:
"""Context manager that raises TimeoutError after `seconds` seconds.
Uses POSIX signal.alarm in the main thread; falls back to a no-op in
worker threads (e.g. Gradio request handlers) where signals are unavailable.
"""
def __init__(self, seconds):
self.seconds = seconds
self._use_signal = False
def _handler(self, signum, frame):
raise TimeoutError(f"Reconstruction timed out after {self.seconds}s")
def __enter__(self):
import threading
if threading.current_thread() is threading.main_thread():
self._use_signal = True
self._old = signal.signal(signal.SIGALRM, self._handler)
signal.alarm(self.seconds)
return self
def __exit__(self, *args):
if self._use_signal:
signal.alarm(0)
signal.signal(signal.SIGALRM, self._old)
def parse_args():
    """Build and parse the command-line options for reconstruction evaluation."""
    p = argparse.ArgumentParser(description="Evaluate signal reconstruction")
    p.add_argument("--checkpoint", type=str, required=True)
    p.add_argument("--domain", type=str, default="ec", choices=["ec", "tpd"])
    p.add_argument("--split", type=str, default="test", choices=["train", "val", "test"])
    p.add_argument("--data_dir", type=str, default=None,
                   help="Override data directory (e.g. for clean test set)")
    p.add_argument("--max_samples", type=int, default=200)
    p.add_argument("--n_posterior_samples", type=int, default=100,
                   help="Posterior samples for reconstruction")
    p.add_argument("--n_recon_samples", type=int, default=10,
                   help="Number of random posterior samples to reconstruct per test sample")
    p.add_argument("--n_visualize", type=int, default=20)
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument("--temperature", type=float, default=1.0,
                   help="Base distribution temperature for sampling (>1 broadens posteriors)")
    return p.parse_args()
# =============================================================================
# EC reconstruction
# =============================================================================
def _safe_pow10(val):
"""Compute 10**val, clamping to avoid OverflowError for extreme values."""
val = np.clip(val, -300, 300)
return 10.0 ** val
def theta_to_ec_params(theta, mechanism, base_params):
    """Map inferred (log-space) parameters back to physical simulator parameters.

    Args:
        theta: [D] numpy array of inferred parameters
        mechanism: str mechanism name
        base_params: dict of fixed parameters from the original sample

    Returns:
        dict of simulator parameters with inferred values substituted in.
    """
    from flow_model import MECHANISM_PARAMS

    params = dict(base_params)
    params['kinetics'] = mechanism
    for i, name in enumerate(MECHANISM_PARAMS[mechanism]['names']):
        value = float(theta[i])
        if name.startswith('log10(') and name.endswith(')'):
            # Strip the "log10(...)" wrapper and exponentiate back to linear scale.
            params[name[6:-1]] = _safe_pow10(value)
        elif name == 'alpha':
            params['alpha'] = value
        elif name == 'E0_offset':
            # E0_offset is not a simulator parameter; intentionally ignored.
            pass
    return params
def reconstruct_ec_signal(theta, mechanism, base_params, sigmas, n_spatial=64):
    """Reconstruct CV signal(s) from inferred parameters.

    Args:
        theta: [D] inferred parameters (model output space)
        mechanism: str
        base_params: dict with fixed params (theta_i, theta_v, dA, etc.)
        sigmas: list of scan rates
        n_spatial: spatial grid points

    Returns:
        list of dicts with 'potential', 'flux', 'time',
        'c_ox_surface', 'c_red_surface' per scan rate
    """
    import warnings
    from generate_dataset_diffec import _run_single_cv

    phys = theta_to_ec_params(theta, mechanism, base_params)
    # Rate constants were inferred at sigma = 1; rescale them per scan rate below.
    k0_ref = phys.get('K0', 1.0)
    kc_ref = phys.get('kc', 1.0)

    out = []
    for sigma in sigmas:
        run_params = dict(phys)
        run_params['sigma'] = float(sigma)
        # Dimensionless rate constants scale with scan rate differently per mechanism.
        if mechanism in ('BV', 'MHC', 'EC', 'LH'):
            run_params['K0'] = k0_ref / np.sqrt(sigma)
        elif mechanism == 'Ads':
            run_params['K0'] = k0_ref / sigma
        if mechanism == 'EC':
            run_params['kc'] = kc_ref / sigma
        try:
            # Bound each simulator call and silence solver RuntimeWarnings.
            with _Timeout(30), warnings.catch_warnings():
                warnings.simplefilter("ignore", RuntimeWarning)
                sim = _run_single_cv(run_params, n_spatial)
            entry = {
                'potential': sim['potential'],
                'flux': sim['flux'],
                'time': sim['time'],
                'success': True,
            }
            if 'c_ox' in sim and 'c_red' in sim:
                # Surface concentrations: Ox at the last grid node, Red at the first.
                entry['c_ox_surface'] = sim['c_ox'][:, -1]
                entry['c_red_surface'] = sim['c_red'][:, 0]
            out.append(entry)
        except Exception as e:
            out.append({'success': False, 'error': str(e)})
    return out
# =============================================================================
# TPD reconstruction
# =============================================================================
def theta_to_tpd_params(theta, mechanism, base_params):
    """Map inferred (log-space) parameters back to physical TPD simulator parameters."""
    from generate_tpd_data import TPD_MECHANISM_PARAMS

    params = dict(base_params)
    params['mechanism'] = mechanism
    for i, name in enumerate(TPD_MECHANISM_PARAMS[mechanism]['names']):
        value = float(theta[i])
        if name.startswith('log10(') and name.endswith(')'):
            # Strip the "log10(...)" wrapper and exponentiate back to linear scale.
            params[name[6:-1]] = _safe_pow10(value)
        else:
            params[name] = value
    return params
def reconstruct_tpd_signal(theta, mechanism, base_params, betas):
    """Reconstruct TPD signal(s) from inferred parameters.

    Args:
        theta: [D] inferred parameters
        mechanism: str
        base_params: dict with T_start, T_end, etc.
        betas: list of heating rates

    Returns:
        list of dicts with 'temperature', 'rate', 'time' per heating rate
    """
    import warnings
    from generate_tpd_data import _run_single_tpd

    phys = theta_to_tpd_params(theta, mechanism, base_params)
    out = []
    for beta in betas:
        run_params = dict(phys)
        run_params['beta'] = float(beta)
        try:
            # Bound each simulator call and silence solver RuntimeWarnings.
            with _Timeout(30), warnings.catch_warnings():
                warnings.simplefilter("ignore", RuntimeWarning)
                sim = _run_single_tpd(run_params)
            out.append({
                'temperature': sim['temperature'],
                'rate': sim['rate'],
                'time': sim['time'],
                'success': True,
            })
        except Exception as e:
            out.append({'success': False, 'error': str(e)})
    return out
# =============================================================================
# Metrics
# =============================================================================
def _valid_signal(arr):
"""Check that a signal array contains no NaN/Inf/extreme values."""
if not np.all(np.isfinite(arr)):
return False
if np.max(np.abs(arr)) > 1e30:
return False
return True
def signal_rmse(observed, reconstructed, length=None):
    """RMSE between two signals, optionally truncated to valid length.

    Returns NaN when either (truncated) signal contains non-finite or
    extreme (>1e30) values.
    """
    n = min(len(observed), len(reconstructed))
    if length is not None:
        n = min(n, length)
    obs = observed[:n]
    rec = reconstructed[:n]
    # Validity check (same criterion as _valid_signal): finite, sane magnitude.
    for sig in (obs, rec):
        if not np.all(np.isfinite(sig)) or np.max(np.abs(sig)) > 1e30:
            return float('nan')
    return float(np.sqrt(np.mean((obs - rec) ** 2)))
def signal_nrmse(observed, reconstructed, length=None):
    """Normalized RMSE (by peak-to-peak range of observed signal).

    Returns NaN for invalid signals and +inf when the observed signal is
    numerically flat (range below 1e-20).
    """
    n = min(len(observed), len(reconstructed))
    if length is not None:
        n = min(n, length)
    obs = observed[:n]
    rec = reconstructed[:n]
    # Validity check (same criterion as _valid_signal): finite, sane magnitude.
    for sig in (obs, rec):
        if not np.all(np.isfinite(sig)) or np.max(np.abs(sig)) > 1e30:
            return float('nan')
    span = np.ptp(obs)
    if span < 1e-20:
        return float('inf')
    return float(np.sqrt(np.mean((obs - rec) ** 2)) / span)
def signal_r2(observed, reconstructed, length=None):
    """R² between observed and reconstructed signals.

    Returns NaN for invalid signals and 0.0 when the observed signal has
    (numerically) zero variance.
    """
    n = min(len(observed), len(reconstructed))
    if length is not None:
        n = min(n, length)
    obs = observed[:n]
    rec = reconstructed[:n]
    # Validity check (same criterion as _valid_signal): finite, sane magnitude.
    for sig in (obs, rec):
        if not np.all(np.isfinite(sig)) or np.max(np.abs(sig)) > 1e30:
            return float('nan')
    residual = np.sum((obs - rec) ** 2)
    total = np.sum((obs - np.mean(obs)) ** 2)
    if total < 1e-20:
        return 0.0
    return float(1 - residual / total)
# =============================================================================
# EC evaluation
# =============================================================================
def evaluate_ec(args):
    """Evaluate CV signal reconstruction for the EC (electrochemistry) domain.

    Loads a trained MultiMechanismFlow checkpoint, infers parameter posteriors
    per test sample, re-runs the CV simulator from the posterior mean, MAP,
    and random posterior draws, then compares reconstructed flux curves with
    the observed ones (NRMSE / R² per mechanism). Writes a JSON summary and
    per-sample comparison plots into an eval directory next to the checkpoint.
    """
    # Project-local imports are deferred so this module can be imported
    # without the full training environment present.
    from multi_mechanism_model import MultiMechanismFlow
    from flow_model import MECHANISM_LIST, MECHANISM_PARAMS
    from dataset import DiffECDataset, collate_fn
    from torch.utils.data import DataLoader
    ckpt_path = os.path.expanduser(args.checkpoint)
    # weights_only=False: checkpoint also stores the training-args dict.
    checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    ckpt_args = checkpoint['args']
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Rebuild the model with the same hyper-parameters it was trained with.
    model = MultiMechanismFlow(
        d_context=ckpt_args.get('d_context', 128),
        d_model=ckpt_args.get('d_model', 128),
        n_coupling_layers=ckpt_args.get('n_coupling_layers', 6),
        hidden_dim=ckpt_args.get('hidden_dim', 96),
        coupling_type=ckpt_args.get('coupling_type', 'spline'),
        n_bins=ckpt_args.get('n_bins', 8),
        tail_bound=ckpt_args.get('tail_bound', 5.0),
    )
    # Checkpoint layout assumed: {run_dir}/checkpoints/best.pt, so the
    # stats JSON files live two levels up in the run directory.
    ckpt_dir = Path(ckpt_path).parent.parent
    theta_stats_path = ckpt_dir / "theta_stats.json"
    with open(theta_stats_path) as f:
        theta_stats = json.load(f)
    # Restore the per-mechanism parameter normalization used at training time.
    for mech in MECHANISM_LIST:
        if mech in theta_stats:
            model.set_theta_stats(
                mech,
                torch.tensor(theta_stats[mech]['mean']),
                torch.tensor(theta_stats[mech]['std']),
            )
    norm_stats_path = ckpt_dir / "norm_stats.json"
    with open(norm_stats_path) as f:
        norm_stats = json.load(f)
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    # Mark data-dependent-init layers as already initialized so they do not
    # re-initialize from the first eval batch.
    # NOTE(review): the hasattr check uses '_initialized' while the attribute
    # read/written is 'initialized' -- confirm these name the same flag; if
    # not, this loop is either a no-op or raises AttributeError.
    for m in model.modules():
        if hasattr(m, '_initialized') and not m.initialized:
            m.initialized = True
    model = model.to(device)
    model.eval()
    if args.data_dir:
        data_dir = os.path.expanduser(args.data_dir)
    else:
        data_dir = os.path.expanduser(ckpt_args.get('data_dir', '~/DiffEC/data'))
    # Support raw per-mechanism directory structure:
    # {data_dir}/{Mechanism}/{split}/sample_*.npz
    # as well as the assembled flat structure:
    # {data_dir}/{split}/sample_*.npz
    split_dir = os.path.join(data_dir, args.split)
    raw_per_mechanism = False
    if not os.path.exists(split_dir) or not glob.glob(os.path.join(split_dir, "sample_*.npz")):
        # Try raw per-mechanism structure
        mech_dirs = [d for d in os.listdir(data_dir)
                     if os.path.isdir(os.path.join(data_dir, d, args.split))]
        if mech_dirs:
            raw_per_mechanism = True
            print(f"Detected raw per-mechanism directory structure in {data_dir}")
            print(f" Mechanisms found: {sorted(mech_dirs)}")
            # Create a temporary flat directory with symlinks
            import tempfile
            tmp_dir = tempfile.mkdtemp(prefix="ecflow_recon_")
            flat_dir = os.path.join(tmp_dir, args.split)
            os.makedirs(flat_dir, exist_ok=True)
            file_idx = 0
            for mech_name in sorted(mech_dirs):
                mech_split = os.path.join(data_dir, mech_name, args.split)
                for f in sorted(glob.glob(os.path.join(mech_split, "sample_*.npz"))):
                    dst = os.path.join(flat_dir, f"sample_{file_idx:06d}.npz")
                    os.symlink(os.path.abspath(f), dst)
                    file_idx += 1
            split_dir = flat_dir
            print(f" Linked {file_idx} samples into temporary flat directory")
    print(f"Loading data from: {split_dir}")
    # Normalized dataset feeds the model; override its stats with the
    # training-time normalization so eval matches training exactly.
    dataset = DiffECDataset(split_dir, max_samples=args.max_samples, normalize_input=True)
    dataset.potential_mean = norm_stats['potential'][0]
    dataset.potential_std = norm_stats['potential'][1]
    dataset.flux_mean = norm_stats['flux'][0]
    dataset.flux_std = norm_stats['flux'][1]
    dataset.time_mean = norm_stats['time'][0]
    dataset.time_std = norm_stats['time'][1]
    # Un-normalized view of the same files provides ground-truth signals.
    raw_dataset = DiffECDataset(split_dir, max_samples=args.max_samples, normalize_input=False)
    # batch_size=1, shuffle=False keeps loader order aligned with
    # raw_dataset.sample_files (indexed by idx below).
    loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
    eval_dir = ckpt_dir / f"eval_recon_{args.split}"
    if raw_per_mechanism:
        eval_dir = ckpt_dir / f"eval_recon_clean_{args.split}"
    eval_dir.mkdir(exist_ok=True)
    # Per-mechanism metric accumulators (one scalar per evaluated sample).
    per_mech_nrmse_mean = defaultdict(list)
    per_mech_nrmse_map = defaultdict(list)
    per_mech_nrmse_samples = defaultdict(list)
    per_mech_r2_mean = defaultdict(list)
    per_mech_r2_map = defaultdict(list)
    n_failed = 0  # NOTE(review): declared but never incremented or reported
    vis_count = defaultdict(int)
    print(f"Evaluating signal reconstruction on {len(dataset)} samples...")
    for idx, batch in enumerate(tqdm(loader, desc="Reconstructing")):
        x = batch['input'].to(device)
        scan_mask = batch['scan_mask'].to(device)
        sigmas_tensor = batch['sigmas'].to(device)
        flux_scales = batch['flux_scales'].to(device)
        mech_id = batch['mechanism_id'].item()
        # Skip samples whose mechanism id is out of the known range.
        if mech_id < 0 or mech_id >= len(MECHANISM_LIST):
            continue
        mech = MECHANISM_LIST[mech_id]
        # Ground truth comes from the raw (un-normalized) npz file.
        raw_data = np.load(raw_dataset.sample_files[idx], allow_pickle=True)
        raw_params = raw_data['params'].item()
        raw_flux = raw_data['flux'].astype(np.float32)
        raw_potential = raw_data['potential'].astype(np.float32)
        if 'sigmas' in raw_data:
            # Multi-scan sample: 2-D [n_scans, T] arrays with per-scan lengths.
            scan_rates = raw_data['sigmas'].astype(np.float64)
            lengths = raw_data['lengths'].astype(int)
        else:
            # Single-scan sample: promote to [1, T] for uniform handling.
            scan_rates = np.array([raw_params.get('sigma', 1.0)])
            lengths = np.array([len(raw_potential)])
            raw_flux = raw_flux[np.newaxis, :]
            raw_potential = raw_potential[np.newaxis, :]
        with torch.no_grad():
            pred = model.predict(x, scan_mask=scan_mask, sigmas=sigmas_tensor,
                                 flux_scales=flux_scales,
                                 n_samples=args.n_posterior_samples,
                                 temperature=args.temperature)
        if pred['stats'][mech] is None:
            continue
        theta_mean = pred['stats'][mech]['mean'][0].cpu().numpy()
        samples = pred['samples'][mech][0].cpu().numpy()
        # MAP estimate via 1D KDE
        from scipy.stats import gaussian_kde
        theta_map = np.zeros_like(theta_mean)
        for d in range(len(theta_mean)):
            s = samples[:, d]
            if np.std(s) < 1e-10:
                # Degenerate (near-constant) posterior: KDE would fail.
                theta_map[d] = np.mean(s)
            else:
                try:
                    kde = gaussian_kde(s)
                    grid = np.linspace(s.min(), s.max(), 200)
                    theta_map[d] = grid[np.argmax(kde(grid))]
                except Exception:
                    # gaussian_kde can raise (e.g. singular covariance);
                    # the median is a robust fallback.
                    theta_map[d] = np.median(s)
        base_params = dict(raw_params)
        # Reconstruct from mean
        recon_mean = reconstruct_ec_signal(theta_mean, mech, base_params, scan_rates)
        # Reconstruct from MAP
        recon_map = reconstruct_ec_signal(theta_map, mech, base_params, scan_rates)
        # Compute metrics per scan rate
        nrmse_mean_list = []
        nrmse_map_list = []
        r2_mean_list = []
        r2_map_list = []
        for s_idx in range(len(scan_rates)):
            obs_flux = raw_flux[s_idx]
            length = lengths[s_idx]
            if recon_mean[s_idx]['success']:
                v = signal_nrmse(obs_flux, recon_mean[s_idx]['flux'], length)
                r = signal_r2(obs_flux, recon_mean[s_idx]['flux'], length)
                if np.isfinite(v):
                    nrmse_mean_list.append(v)
                if np.isfinite(r):
                    r2_mean_list.append(r)
            if recon_map[s_idx]['success']:
                v = signal_nrmse(obs_flux, recon_map[s_idx]['flux'], length)
                r = signal_r2(obs_flux, recon_map[s_idx]['flux'], length)
                if np.isfinite(v):
                    nrmse_map_list.append(v)
                if np.isfinite(r):
                    r2_map_list.append(r)
        # Average over scan rates; accumulators are filled independently, so
        # the mean/MAP lists for a mechanism may end up with different lengths.
        if nrmse_mean_list:
            per_mech_nrmse_mean[mech].append(np.mean(nrmse_mean_list))
        if r2_mean_list:
            per_mech_r2_mean[mech].append(np.mean(r2_mean_list))
        if nrmse_map_list:
            per_mech_nrmse_map[mech].append(np.mean(nrmse_map_list))
        if r2_map_list:
            per_mech_r2_map[mech].append(np.mean(r2_map_list))
        # Reconstruct from random posterior samples
        sample_nrmses = []
        n_recon = min(args.n_recon_samples, samples.shape[0])
        sample_indices = np.random.choice(samples.shape[0], n_recon, replace=False)
        for si in sample_indices:
            recon_s = reconstruct_ec_signal(samples[si], mech, base_params, scan_rates)
            nrmses = []
            for s_idx in range(len(scan_rates)):
                if recon_s[s_idx]['success']:
                    v = signal_nrmse(raw_flux[s_idx], recon_s[s_idx]['flux'], lengths[s_idx])
                    if np.isfinite(v):
                        nrmses.append(v)
            if nrmses:
                sample_nrmses.append(np.mean(nrmses))
        if sample_nrmses:
            # Median across posterior draws is robust to occasional bad draws.
            per_mech_nrmse_samples[mech].append(np.median(sample_nrmses))
        # Visualization
        if vis_count[mech] < args.n_visualize and recon_mean[0]['success']:
            fig, axes = plt.subplots(1, len(scan_rates), figsize=(5 * len(scan_rates), 4))
            if len(scan_rates) == 1:
                # plt.subplots returns a bare Axes for a single panel.
                axes = [axes]
            for s_idx, ax in enumerate(axes):
                length = lengths[s_idx]
                obs_pot = raw_potential[s_idx, :length]
                obs_flux_s = raw_flux[s_idx, :length]
                ax.plot(obs_pot, obs_flux_s, 'k-', lw=1.5, label='Observed', alpha=0.8)
                if recon_mean[s_idx]['success']:
                    r_pot = recon_mean[s_idx]['potential']
                    r_flux = recon_mean[s_idx]['flux']
                    min_len = min(length, len(r_pot))
                    nrmse_val = signal_nrmse(obs_flux_s, r_flux[:length] if length <= len(r_flux) else r_flux, length)
                    lbl = f'Mean (NRMSE={nrmse_val:.3f})' if np.isfinite(nrmse_val) else 'Mean (NRMSE=N/A)'
                    ax.plot(r_pot[:min_len], r_flux[:min_len], 'r--', lw=1.2, label=lbl)
                if recon_map[s_idx]['success']:
                    r_pot = recon_map[s_idx]['potential']
                    r_flux = recon_map[s_idx]['flux']
                    min_len = min(length, len(r_pot))
                    nrmse_val = signal_nrmse(obs_flux_s, r_flux[:length] if length <= len(r_flux) else r_flux, length)
                    lbl = f'MAP (NRMSE={nrmse_val:.3f})' if np.isfinite(nrmse_val) else 'MAP (NRMSE=N/A)'
                    ax.plot(r_pot[:min_len], r_flux[:min_len], 'b:', lw=1.2, label=lbl)
                # Plot a few posterior samples
                for si in sample_indices[:3]:
                    recon_s = reconstruct_ec_signal(samples[si], mech, base_params, [scan_rates[s_idx]])
                    if recon_s[0]['success']:
                        r_pot = recon_s[0]['potential']
                        r_flux = recon_s[0]['flux']
                        min_len = min(length, len(r_pot))
                        ax.plot(r_pot[:min_len], r_flux[:min_len], '-', lw=0.5,
                                alpha=0.3, color='gray')
                ax.set_xlabel('Potential (θ)')
                ax.set_ylabel('Flux')
                ax.set_title(f'σ={scan_rates[s_idx]:.2f}')
                ax.legend(fontsize=7)
            fig.suptitle(f'{mech} sample {idx}', fontsize=12)
            plt.tight_layout()
            plt.savefig(eval_dir / f"recon_{mech}_{vis_count[mech]:03d}.png", dpi=150)
            plt.close(fig)
            vis_count[mech] += 1
    # Print and save results
    print("\n" + "=" * 60)
    print("SIGNAL RECONSTRUCTION METRICS")
    print("=" * 60)
    results = {}
    for mech in MECHANISM_LIST:
        if not per_mech_nrmse_mean[mech]:
            continue
        nrmse_mean = np.array(per_mech_nrmse_mean[mech])
        nrmse_map = np.array(per_mech_nrmse_map[mech])
        nrmse_samp = np.array(per_mech_nrmse_samples[mech]) if per_mech_nrmse_samples[mech] else np.array([])
        r2_mean = np.array(per_mech_r2_mean[mech])
        r2_map = np.array(per_mech_r2_map[mech])
        results[mech] = {
            'n_samples': len(nrmse_mean),
            'nrmse_mean': {'median': float(np.median(nrmse_mean)),
                           'mean': float(np.mean(nrmse_mean)),
                           'std': float(np.std(nrmse_mean)),
                           'q25': float(np.percentile(nrmse_mean, 25)),
                           'q75': float(np.percentile(nrmse_mean, 75))},
            'nrmse_map': {'median': float(np.median(nrmse_map)),
                          'mean': float(np.mean(nrmse_map)),
                          'std': float(np.std(nrmse_map))},
            'r2_signal_mean': {'median': float(np.median(r2_mean)),
                               'mean': float(np.mean(r2_mean))},
            'r2_signal_map': {'median': float(np.median(r2_map)),
                              'mean': float(np.mean(r2_map))},
        }
        if len(nrmse_samp) > 0:
            results[mech]['nrmse_posterior_median'] = {
                'median': float(np.median(nrmse_samp)),
                'mean': float(np.mean(nrmse_samp)),
            }
        print(f"\n{mech} ({len(nrmse_mean)} samples):")
        print(f" Signal NRMSE (mean est): median={np.median(nrmse_mean):.4f} "
              f"mean={np.mean(nrmse_mean):.4f} ± {np.std(nrmse_mean):.4f}")
        print(f" Signal NRMSE (MAP est): median={np.median(nrmse_map):.4f} "
              f"mean={np.mean(nrmse_map):.4f} ± {np.std(nrmse_map):.4f}")
        if len(nrmse_samp) > 0:
            print(f" Signal NRMSE (post. med): median={np.median(nrmse_samp):.4f} "
                  f"mean={np.mean(nrmse_samp):.4f}")
        print(f" Signal R² (mean est): median={np.median(r2_mean):.4f}")
        print(f" Signal R² (MAP est): median={np.median(r2_map):.4f}")
    with open(eval_dir / "reconstruction_results.json", "w") as f:
        json.dump(results, f, indent=2)
    # Clean up the temporary symlink directory created above, if any.
    if raw_per_mechanism:
        import shutil
        shutil.rmtree(tmp_dir, ignore_errors=True)
    print(f"\nResults saved to {eval_dir}")
    print(f"Visualizations: {sum(vis_count.values())} plots saved")
# =============================================================================
# TPD evaluation
# =============================================================================
def evaluate_tpd(args):
    """Evaluate TPD signal reconstruction (temperature-programmed desorption).

    Mirrors evaluate_ec for the TPD domain: loads a MultiMechanismFlowTPD
    checkpoint, infers posteriors per sample, re-runs the TPD simulator from
    the posterior mean, MAP, and random posterior draws, and compares the
    reconstructed desorption-rate curves with the observed ones. Writes a
    JSON summary and comparison plots next to the checkpoint.
    """
    # Project-local imports are deferred so this module can be imported
    # without the full training environment present.
    from tpd_model import MultiMechanismFlowTPD
    from generate_tpd_data import TPD_MECHANISM_LIST, TPD_MECHANISM_PARAMS
    from dataset_tpd import TPDDataset, collate_fn
    from torch.utils.data import DataLoader
    ckpt_path = os.path.expanduser(args.checkpoint)
    # weights_only=False: checkpoint also stores the training-args dict.
    checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    ckpt_args = checkpoint['args']
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Rebuild the model with the same hyper-parameters it was trained with.
    model = MultiMechanismFlowTPD(
        d_context=ckpt_args.get('d_context', 128),
        d_model=ckpt_args.get('d_model', 128),
        n_coupling_layers=ckpt_args.get('n_coupling_layers', 6),
        hidden_dim=ckpt_args.get('hidden_dim', 96),
        coupling_type=ckpt_args.get('coupling_type', 'spline'),
        n_bins=ckpt_args.get('n_bins', 8),
        tail_bound=ckpt_args.get('tail_bound', 5.0),
    )
    # Checkpoint layout assumed: {run_dir}/checkpoints/best.pt, so the
    # stats JSON files live two levels up in the run directory.
    ckpt_dir = Path(ckpt_path).parent.parent
    theta_stats_path = ckpt_dir / "theta_stats.json"
    with open(theta_stats_path) as f:
        theta_stats = json.load(f)
    # Restore the per-mechanism parameter normalization used at training time.
    for mech in TPD_MECHANISM_LIST:
        if mech in theta_stats:
            model.set_theta_stats(
                mech,
                torch.tensor(theta_stats[mech]['mean']),
                torch.tensor(theta_stats[mech]['std']),
            )
    norm_stats_path = ckpt_dir / "norm_stats.json"
    with open(norm_stats_path) as f:
        norm_stats = json.load(f)
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    # Mark data-dependent-init layers as already initialized so they do not
    # re-initialize from the first eval batch.
    # NOTE(review): the hasattr check uses '_initialized' while the attribute
    # read/written is 'initialized' -- confirm these name the same flag.
    for m in model.modules():
        if hasattr(m, '_initialized') and not m.initialized:
            m.initialized = True
    model = model.to(device)
    model.eval()
    if args.data_dir:
        data_dir = os.path.expanduser(args.data_dir)
    else:
        data_dir = os.path.expanduser(ckpt_args.get('data_dir', '~/ECFlow/data_tpd_multiheat'))
    # Support both the flat {data_dir}/{split} layout and the raw
    # per-mechanism {data_dir}/{Mechanism}/{split} layout (see evaluate_ec).
    split_dir = os.path.join(data_dir, args.split)
    raw_per_mechanism = False
    if not os.path.exists(split_dir) or not glob.glob(os.path.join(split_dir, "sample_*.npz")):
        mech_dirs = [d for d in os.listdir(data_dir)
                     if os.path.isdir(os.path.join(data_dir, d, args.split))]
        if mech_dirs:
            raw_per_mechanism = True
            print(f"Detected raw per-mechanism directory structure in {data_dir}")
            print(f" Mechanisms found: {sorted(mech_dirs)}")
            # Link all per-mechanism samples into a temporary flat directory.
            import tempfile
            tmp_dir = tempfile.mkdtemp(prefix="ecflow_recon_tpd_")
            flat_dir = os.path.join(tmp_dir, args.split)
            os.makedirs(flat_dir, exist_ok=True)
            file_idx = 0
            for mech_name in sorted(mech_dirs):
                mech_split = os.path.join(data_dir, mech_name, args.split)
                for f in sorted(glob.glob(os.path.join(mech_split, "sample_*.npz"))):
                    dst = os.path.join(flat_dir, f"sample_{file_idx:06d}.npz")
                    os.symlink(os.path.abspath(f), dst)
                    file_idx += 1
            split_dir = flat_dir
            print(f" Linked {file_idx} samples into temporary flat directory")
    print(f"Loading data from: {split_dir}")
    # Normalized dataset feeds the model; override its stats with the
    # training-time normalization so eval matches training exactly.
    dataset = TPDDataset(split_dir, max_samples=args.max_samples, normalize_input=True)
    dataset.temperature_mean = norm_stats['temperature'][0]
    dataset.temperature_std = norm_stats['temperature'][1]
    dataset.rate_mean = norm_stats['rate'][0]
    dataset.rate_std = norm_stats['rate'][1]
    # Un-normalized view of the same files provides ground-truth signals.
    raw_dataset = TPDDataset(split_dir, max_samples=args.max_samples, normalize_input=False)
    # batch_size=1, shuffle=False keeps loader order aligned with
    # raw_dataset.sample_files (indexed by idx below).
    loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
    eval_dir = ckpt_dir / f"eval_recon_{args.split}"
    if raw_per_mechanism:
        eval_dir = ckpt_dir / f"eval_recon_clean_{args.split}"
    eval_dir.mkdir(exist_ok=True)
    # Per-mechanism metric accumulators (one scalar per evaluated sample).
    per_mech_nrmse_mean = defaultdict(list)
    per_mech_nrmse_map = defaultdict(list)
    per_mech_nrmse_samples = defaultdict(list)
    per_mech_r2_mean = defaultdict(list)
    per_mech_r2_map = defaultdict(list)
    vis_count = defaultdict(int)
    print(f"Evaluating TPD signal reconstruction on {len(dataset)} samples...")
    for idx, batch in enumerate(tqdm(loader, desc="Reconstructing")):
        x = batch['input'].to(device)
        scan_mask = batch['scan_mask'].to(device)
        sigmas_tensor = batch['sigmas'].to(device)
        flux_scales = batch['flux_scales'].to(device)
        mech_id = batch['mechanism_id'].item()
        # Skip samples whose mechanism id is out of the known range.
        if mech_id < 0 or mech_id >= len(TPD_MECHANISM_LIST):
            continue
        mech = TPD_MECHANISM_LIST[mech_id]
        # Ground truth comes from the raw (un-normalized) npz file.
        raw_data = np.load(raw_dataset.sample_files[idx], allow_pickle=True)
        raw_params = raw_data['params'].item()
        raw_rate = raw_data['rate'].astype(np.float32)
        raw_temp = raw_data['temperature'].astype(np.float32)
        if 'heating_rates' in raw_data:
            # Multi-heating-rate sample: 2-D arrays with per-curve lengths.
            betas = raw_data['heating_rates'].astype(np.float64)
            lengths = raw_data['lengths'].astype(int)
        else:
            # Single heating rate: promote to [1, T] for uniform handling.
            betas = np.array([raw_params.get('beta', 1.0)])
            lengths = np.array([len(raw_rate)])
            raw_rate = raw_rate[np.newaxis, :]
            raw_temp = raw_temp[np.newaxis, :]
        with torch.no_grad():
            pred = model.predict(x, scan_mask=scan_mask, sigmas=sigmas_tensor,
                                 flux_scales=flux_scales,
                                 n_samples=args.n_posterior_samples,
                                 temperature=args.temperature)
        if pred['stats'][mech] is None:
            continue
        theta_mean = pred['stats'][mech]['mean'][0].cpu().numpy()
        samples = pred['samples'][mech][0].cpu().numpy()
        # MAP estimate via per-dimension 1D KDE over the posterior samples.
        from scipy.stats import gaussian_kde
        theta_map = np.zeros_like(theta_mean)
        for d in range(len(theta_mean)):
            s = samples[:, d]
            if np.std(s) < 1e-10:
                # Degenerate (near-constant) posterior: KDE would fail.
                theta_map[d] = np.mean(s)
            else:
                try:
                    kde = gaussian_kde(s)
                    grid = np.linspace(s.min(), s.max(), 200)
                    theta_map[d] = grid[np.argmax(kde(grid))]
                except Exception:
                    # gaussian_kde can raise; the median is a robust fallback.
                    theta_map[d] = np.median(s)
        base_params = dict(raw_params)
        # Reconstruct from posterior mean and MAP estimates.
        recon_mean = reconstruct_tpd_signal(theta_mean, mech, base_params, betas)
        recon_map = reconstruct_tpd_signal(theta_map, mech, base_params, betas)
        # Per-heating-rate metrics, then averaged per sample.
        nrmse_mean_list = []
        nrmse_map_list = []
        r2_mean_list = []
        r2_map_list = []
        for s_idx in range(len(betas)):
            obs_rate = raw_rate[s_idx]
            length = lengths[s_idx]
            if recon_mean[s_idx]['success']:
                v = signal_nrmse(obs_rate, recon_mean[s_idx]['rate'], length)
                r = signal_r2(obs_rate, recon_mean[s_idx]['rate'], length)
                if np.isfinite(v):
                    nrmse_mean_list.append(v)
                if np.isfinite(r):
                    r2_mean_list.append(r)
            if recon_map[s_idx]['success']:
                v = signal_nrmse(obs_rate, recon_map[s_idx]['rate'], length)
                r = signal_r2(obs_rate, recon_map[s_idx]['rate'], length)
                if np.isfinite(v):
                    nrmse_map_list.append(v)
                if np.isfinite(r):
                    r2_map_list.append(r)
        if nrmse_mean_list:
            per_mech_nrmse_mean[mech].append(np.mean(nrmse_mean_list))
        if r2_mean_list:
            per_mech_r2_mean[mech].append(np.mean(r2_mean_list))
        if nrmse_map_list:
            per_mech_nrmse_map[mech].append(np.mean(nrmse_map_list))
        if r2_map_list:
            per_mech_r2_map[mech].append(np.mean(r2_map_list))
        # Reconstruct from random posterior draws.
        sample_nrmses = []
        n_recon = min(args.n_recon_samples, samples.shape[0])
        sample_indices = np.random.choice(samples.shape[0], n_recon, replace=False)
        for si in sample_indices:
            recon_s = reconstruct_tpd_signal(samples[si], mech, base_params, betas)
            nrmses = []
            for s_idx in range(len(betas)):
                if recon_s[s_idx]['success']:
                    v = signal_nrmse(raw_rate[s_idx], recon_s[s_idx]['rate'], lengths[s_idx])
                    if np.isfinite(v):
                        nrmses.append(v)
            if nrmses:
                sample_nrmses.append(np.mean(nrmses))
        if sample_nrmses:
            # Median across posterior draws is robust to occasional bad draws.
            per_mech_nrmse_samples[mech].append(np.median(sample_nrmses))
        # Visualization
        if vis_count[mech] < args.n_visualize and recon_mean[0]['success']:
            fig, axes = plt.subplots(1, len(betas), figsize=(5 * len(betas), 4))
            if len(betas) == 1:
                # plt.subplots returns a bare Axes for a single panel.
                axes = [axes]
            for s_idx, ax in enumerate(axes):
                length = lengths[s_idx]
                obs_temp = raw_temp[s_idx, :length]
                obs_rate_s = raw_rate[s_idx, :length]
                ax.plot(obs_temp, obs_rate_s, 'k-', lw=1.5, label='Observed', alpha=0.8)
                if recon_mean[s_idx]['success']:
                    r_temp = recon_mean[s_idx]['temperature']
                    r_rate = recon_mean[s_idx]['rate']
                    min_len = min(length, len(r_temp))
                    nrmse_val = signal_nrmse(obs_rate_s, r_rate[:length] if length <= len(r_rate) else r_rate, length)
                    lbl = f'Mean (NRMSE={nrmse_val:.3f})' if np.isfinite(nrmse_val) else 'Mean (NRMSE=N/A)'
                    ax.plot(r_temp[:min_len], r_rate[:min_len], 'r--', lw=1.2, label=lbl)
                if recon_map[s_idx]['success']:
                    r_temp = recon_map[s_idx]['temperature']
                    r_rate = recon_map[s_idx]['rate']
                    min_len = min(length, len(r_temp))
                    nrmse_val = signal_nrmse(obs_rate_s, r_rate[:length] if length <= len(r_rate) else r_rate, length)
                    lbl = f'MAP (NRMSE={nrmse_val:.3f})' if np.isfinite(nrmse_val) else 'MAP (NRMSE=N/A)'
                    ax.plot(r_temp[:min_len], r_rate[:min_len], 'b:', lw=1.2, label=lbl)
                # Overlay a few posterior-sample reconstructions.
                for si in sample_indices[:3]:
                    recon_s = reconstruct_tpd_signal(samples[si], mech, base_params, [betas[s_idx]])
                    if recon_s[0]['success']:
                        r_temp = recon_s[0]['temperature']
                        r_rate = recon_s[0]['rate']
                        min_len = min(length, len(r_temp))
                        ax.plot(r_temp[:min_len], r_rate[:min_len], '-', lw=0.5,
                                alpha=0.3, color='gray')
                ax.set_xlabel('Temperature (K)')
                ax.set_ylabel('Rate')
                ax.set_title(f'β={betas[s_idx]:.2f} K/s')
                ax.legend(fontsize=7)
            fig.suptitle(f'{mech} sample {idx}', fontsize=12)
            plt.tight_layout()
            plt.savefig(eval_dir / f"recon_{mech}_{vis_count[mech]:03d}.png", dpi=150)
            plt.close(fig)
            vis_count[mech] += 1
    # Print and save results
    print("\n" + "=" * 60)
    print("SIGNAL RECONSTRUCTION METRICS (TPD)")
    print("=" * 60)
    results = {}
    for mech in TPD_MECHANISM_LIST:
        if not per_mech_nrmse_mean[mech]:
            continue
        nrmse_mean = np.array(per_mech_nrmse_mean[mech])
        nrmse_map = np.array(per_mech_nrmse_map[mech])
        nrmse_samp = np.array(per_mech_nrmse_samples[mech]) if per_mech_nrmse_samples[mech] else np.array([])
        r2_mean = np.array(per_mech_r2_mean[mech])
        r2_map = np.array(per_mech_r2_map[mech])
        results[mech] = {
            'n_samples': len(nrmse_mean),
            'nrmse_mean': {'median': float(np.median(nrmse_mean)),
                           'mean': float(np.mean(nrmse_mean)),
                           'std': float(np.std(nrmse_mean))},
            'nrmse_map': {'median': float(np.median(nrmse_map)),
                          'mean': float(np.mean(nrmse_map)),
                          'std': float(np.std(nrmse_map))},
            'r2_signal_mean': {'median': float(np.median(r2_mean)),
                               'mean': float(np.mean(r2_mean))},
            'r2_signal_map': {'median': float(np.median(r2_map)),
                              'mean': float(np.mean(r2_map))},
        }
        if len(nrmse_samp) > 0:
            results[mech]['nrmse_posterior_median'] = {
                'median': float(np.median(nrmse_samp)),
                'mean': float(np.mean(nrmse_samp)),
            }
        print(f"\n{mech} ({len(nrmse_mean)} samples):")
        print(f" Signal NRMSE (mean est): median={np.median(nrmse_mean):.4f} "
              f"mean={np.mean(nrmse_mean):.4f} ± {np.std(nrmse_mean):.4f}")
        print(f" Signal NRMSE (MAP est): median={np.median(nrmse_map):.4f} "
              f"mean={np.mean(nrmse_map):.4f} ± {np.std(nrmse_map):.4f}")
        if len(nrmse_samp) > 0:
            print(f" Signal NRMSE (post. med): median={np.median(nrmse_samp):.4f} "
                  f"mean={np.mean(nrmse_samp):.4f}")
        print(f" Signal R² (mean est): median={np.median(r2_mean):.4f}")
        print(f" Signal R² (MAP est): median={np.median(r2_map):.4f}")
    with open(eval_dir / "reconstruction_results.json", "w") as f:
        json.dump(results, f, indent=2)
    # Clean up the temporary symlink directory created above, if any.
    if raw_per_mechanism:
        import shutil
        shutil.rmtree(tmp_dir, ignore_errors=True)
    print(f"\nResults saved to {eval_dir}")
    print(f"Visualizations: {sum(vis_count.values())} plots saved")
if __name__ == "__main__":
    cli_args = parse_args()
    # Dispatch on the requested physical domain (EC vs. TPD).
    run = evaluate_ec if cli_args.domain == "ec" else evaluate_tpd
    run(cli_args)