"""
ECFlow inference engine.

Loads trained EC and TPD models and runs end-to-end inference from
preprocessed arrays (dimensionless for CV, physical for TPD).
"""

import json
import logging
import os
import sys
from pathlib import Path

import numpy as np
import torch

from multi_mechanism_model import MultiMechanismFlow
from tpd_model import MultiMechanismFlowTPD
from flow_model import MECHANISM_LIST, MECHANISM_PARAMS, ActNorm
from generate_tpd_data import TPD_MECHANISM_LIST, TPD_MECHANISM_PARAMS

logger = logging.getLogger(__name__)


def _fix_actnorm_initialized(model):
    """Mark all ActNorm layers as initialized after loading a checkpoint.

    Old checkpoints lack the ``_initialized`` buffer, so ``load_state_dict``
    leaves it at ``False``.  The first forward pass would then overwrite the
    trained ``log_scale``/``bias`` with data-dependent statistics.
    """
    for module in model.modules():
        if isinstance(module, ActNorm) and not module.initialized:
            module.initialized = True


def _flow_model_kwargs(args):
    """Return the architecture kwargs shared by the EC and TPD flow models.

    Args:
        args: the ``"args"`` mapping stored in a training checkpoint; keys
            missing from it fall back to the defaults below.
    """
    return {
        "d_context": args.get("d_context", 128),
        "d_model": args.get("d_model", 128),
        "n_coupling_layers": args.get("n_coupling_layers", 6),
        "hidden_dim": args.get("hidden_dim", 96),
        "coupling_type": args.get("coupling_type", "spline"),
        "n_bins": args.get("n_bins", 8),
        "tail_bound": args.get("tail_bound", 5.0),
    }


def _load_sidecar_json(ckpt_path, kind, domain):
    """Find and load a ``<kind>.json`` stats file stored alongside a checkpoint.

    The checkpoint's directory is searched first, then its parent.  Within
    each directory the candidate names are tried in this order:

    1. ``<prefix><kind>.json``, where ``<prefix>`` is the checkpoint stem with
       ``"best"`` removed and trailing underscores stripped, followed by
       ``"_"`` (e.g. ``ec_best.pt`` -> ``ec_norm_stats.json``); empty when
       the stem reduces to nothing;
    2. ``<domain>_<kind>.json``;
    3. ``<kind>.json``.

    Args:
        ckpt_path: path to the checkpoint file.
        kind: base name of the stats file, e.g. ``"norm_stats"``.
        domain: domain tag used in the second candidate, e.g. ``"ec"``.

    Returns:
        The parsed JSON of the first matching file, or ``None`` if no
        candidate exists.
    """
    ckpt_path = Path(ckpt_path)
    stem = ckpt_path.stem.replace("best", "").rstrip("_")
    prefix = f"{stem}_" if stem else ""
    candidates = (f"{prefix}{kind}.json", f"{domain}_{kind}.json", f"{kind}.json")
    for search_dir in (ckpt_path.parent, ckpt_path.parent.parent):
        for name in candidates:
            path = search_dir / name
            if path.exists():
                with open(path, encoding="utf-8") as f:
                    return json.load(f)
    return None


def _summarize_posterior(samples, mechanism_list, mechanism_params):
    """Convert per-mechanism posterior samples to numpy and summarize them.

    Args:
        samples: mapping mechanism name -> sample tensor or ``None`` (the
            ``"samples"`` entry of a model ``predict`` output).  Only batch
            element 0 is used.
        mechanism_list: mechanism names to visit, in output order.
        mechanism_params: mapping mechanism name -> dict with a ``"names"``
            list of parameter names.

    Returns:
        ``(parameter_stats, posterior_samples)``, both keyed by mechanism.
        Mechanisms whose samples are ``None`` are omitted from both.
    """
    parameter_stats = {}
    posterior_samples = {}
    for mech in mechanism_list:
        mech_samples = samples[mech]
        if mech_samples is None:
            continue
        s = mech_samples[0].cpu().numpy()  # [n_samples, D]
        posterior_samples[mech] = s
        parameter_stats[mech] = {
            "names": mechanism_params[mech]["names"],
            "mean": s.mean(axis=0).tolist(),
            "std": s.std(axis=0).tolist(),
            "median": np.median(s, axis=0).tolist(),
            "q05": np.quantile(s, 0.05, axis=0).tolist(),
            "q95": np.quantile(s, 0.95, axis=0).tolist(),
        }
    return parameter_stats, posterior_samples


def _format_prediction(domain, pred, mechanism_list, mechanism_params):
    """Build the public result dict returned by ``predict_ec``/``predict_tpd``.

    Args:
        domain: ``"ec"`` or ``"tpd"``; stored under ``"domain"``.
        pred: raw output of the model's ``predict`` method (batch size 1).
        mechanism_list: mechanism names, indexed by the model's class ids.
        mechanism_params: per-mechanism parameter metadata (``"names"``).
    """
    probs = pred["mechanism_probs"][0].cpu().numpy()
    pred_idx = int(pred["mechanism_pred"][0].cpu().item())
    pred_mech = mechanism_list[pred_idx]
    param_stats, samples_dict = _summarize_posterior(
        pred["samples"], mechanism_list, mechanism_params
    )
    return {
        "domain": domain,
        "mechanism_probs": {m: float(probs[i]) for i, m in enumerate(mechanism_list)},
        "mechanism_names": mechanism_list,
        "predicted_mechanism": pred_mech,
        "predicted_mechanism_idx": pred_idx,
        "parameter_stats": param_stats,
        "posterior_samples": samples_dict,
    }


def _finite_mean(values):
    """Return the mean of the finite entries of ``values`` (NaN if none)."""
    finite = [v for v in values if np.isfinite(v)]
    return float(np.mean(finite)) if finite else float("nan")


class ECFlowPredictor:
    """Unified predictor for both EC (cyclic voltammetry) and TPD domains.

    After construction, ``ec_model``/``tpd_model`` hold the loaded models (or
    ``None``), and ``*_norm_stats``/``*_theta_stats`` hold the JSON stats
    found next to each checkpoint (or ``None`` when no file was found).
    """

    # Number of uniformly resampled points per curve fed to each model.
    # NOTE(review): presumably must match the sequence length used during
    # training — confirm before changing.
    EC_SEQ_LEN = 672
    TPD_SEQ_LEN = 500

    def __init__(self, ec_checkpoint=None, tpd_checkpoint=None, device=None):
        """
        Args:
            ec_checkpoint: optional path to an EC checkpoint to load.
            tpd_checkpoint: optional path to a TPD checkpoint to load.
            device: torch device string; defaults to ``"cuda"`` when
                available, else ``"cpu"``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        self.ec_model = None
        self.ec_norm_stats = None
        self.ec_theta_stats = None

        self.tpd_model = None
        self.tpd_norm_stats = None
        self.tpd_theta_stats = None
        self.tpd_use_summary = False

        if ec_checkpoint is not None:
            self._load_ec(ec_checkpoint)
        if tpd_checkpoint is not None:
            self._load_tpd(tpd_checkpoint)

    def _load_ec(self, ckpt_path):
        """Load the EC flow model and its sidecar stats from ``ckpt_path``.

        Replaces ``ec_model``, ``ec_norm_stats`` and ``ec_theta_stats``.
        Stats that cannot be found for this checkpoint are reset to ``None``
        instead of being carried over from a previously loaded checkpoint.
        """
        ckpt_path = Path(ckpt_path)
        # NOTE: weights_only=False unpickles arbitrary Python objects; only
        # load checkpoints from trusted sources.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        args = checkpoint["args"]

        model = MultiMechanismFlow(**_flow_model_kwargs(args))
        # strict=False presumably lets older checkpoints that lack newer
        # buffers (see _fix_actnorm_initialized) still load.
        model.load_state_dict(checkpoint["model_state_dict"], strict=False)
        _fix_actnorm_initialized(model)
        model.to(self.device).eval()
        self.ec_model = model

        self.ec_norm_stats = _load_sidecar_json(ckpt_path, "norm_stats", "ec")
        self.ec_theta_stats = _load_sidecar_json(ckpt_path, "theta_stats", "ec")
        if self.ec_norm_stats is None:
            logger.warning(
                "No norm_stats JSON found for %s; EC inputs will not be normalized.",
                ckpt_path,
            )

    def _load_tpd(self, ckpt_path):
        """Load the TPD flow model and its sidecar stats from ``ckpt_path``.

        Replaces ``tpd_model``, ``tpd_use_summary``, ``tpd_norm_stats`` and
        ``tpd_theta_stats``.  Stats that cannot be found for this checkpoint
        are reset to ``None`` instead of being carried over.
        """
        ckpt_path = Path(ckpt_path)
        # NOTE: weights_only=False unpickles arbitrary Python objects; only
        # load checkpoints from trusted sources.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        args = checkpoint["args"]
        self.tpd_use_summary = args.get("use_summary_features", False)

        model = MultiMechanismFlowTPD(
            **_flow_model_kwargs(args),
            use_summary_features=self.tpd_use_summary,
        )
        model.load_state_dict(checkpoint["model_state_dict"], strict=False)
        _fix_actnorm_initialized(model)
        model.to(self.device).eval()
        self.tpd_model = model

        self.tpd_norm_stats = _load_sidecar_json(ckpt_path, "norm_stats", "tpd")
        self.tpd_theta_stats = _load_sidecar_json(ckpt_path, "theta_stats", "tpd")
        if self.tpd_norm_stats is None:
            logger.warning(
                "No norm_stats JSON found for %s; TPD inputs will not be normalized.",
                ckpt_path,
            )

    def _prepare_ec_tensor(self, potentials, fluxes, times, sigmas):
        """
        Build model input tensors from preprocessed dimensionless CV data.

        Each scan is resampled onto ``EC_SEQ_LEN`` uniformly spaced time
        points. Its flux is divided by its own peak magnitude, and the log10
        of that peak is passed separately as ``flux_scales``. When norm stats
        are loaded, each channel is standardized as
        ``(x - stats[0]) / stats[1]``.

        Args:
            potentials: list of 1-D arrays (dimensionless theta)
            fluxes: list of 1-D arrays (dimensionless flux)
            times: list of 1-D arrays (dimensionless time) or None; individual
                entries may also be None, in which case a time axis is
                synthesized from the potential window and scan rate
            sigmas: 1-D array of dimensionless scan rates, one per scan

        Returns:
            dict of tensors ready for model.predict(): ``input`` [1, N, 3, T],
            ``scan_mask`` [1, N, T] (all True), ``sigmas`` (log10 scan rates)
            and ``flux_scales`` (log10 per-scan flux peaks).

        Raises:
            ValueError: if no scans are given, or the number of flux arrays
                or scan rates differs from the number of potential arrays.
        """
        from scipy.interpolate import interp1d

        n_scans = len(potentials)
        if n_scans == 0:
            raise ValueError("At least one CV scan is required.")
        if len(fluxes) != n_scans or np.size(sigmas) != n_scans:
            raise ValueError(
                f"Expected {n_scans} flux arrays and scan rates, "
                f"got {len(fluxes)} and {np.size(sigmas)}."
            )
        n_points = self.EC_SEQ_LEN

        pot_rows, flux_rows, time_rows, flux_scales = [], [], [], []
        for i in range(n_scans):
            pot = np.asarray(potentials[i], dtype=np.float32)
            flx = np.asarray(fluxes[i], dtype=np.float32)
            if times is not None and times[i] is not None:
                tim = np.asarray(times[i], dtype=np.float32)
            else:
                # Synthesize a uniform time axis lasting 2 * window / rate,
                # i.e. one forward plus one reverse sweep of the window.
                theta_range = pot.max() - pot.min()
                total_time = 2.0 * theta_range / sigmas[i]
                tim = np.linspace(0, total_time, len(pot), dtype=np.float32)

            # Per-scan peak normalization; the epsilon guards an all-zero trace.
            peak = np.max(np.abs(flx)) + 1e-30
            flux_scales.append(np.log10(peak))
            flx = flx / peak

            t_uniform = np.linspace(tim[0], tim[-1], n_points)
            pot_rows.append(
                interp1d(tim, pot, kind="linear", fill_value="extrapolate")(t_uniform)
            )
            flux_rows.append(
                interp1d(tim, flx, kind="linear", fill_value="extrapolate")(t_uniform)
            )
            time_rows.append(t_uniform)

        pot_arr = np.stack(pot_rows).astype(np.float32)
        flx_arr = np.stack(flux_rows).astype(np.float32)
        tim_arr = np.stack(time_rows).astype(np.float32)

        ns = self.ec_norm_stats
        if ns:
            pot_arr = (pot_arr - ns["potential"][0]) / ns["potential"][1]
            flx_arr = (flx_arr - ns["flux"][0]) / ns["flux"][1]
            tim_arr = (tim_arr - ns["time"][0]) / ns["time"][1]

        # [1, N, 3, T]
        waveforms = np.stack([pot_arr, flx_arr, tim_arr], axis=1)
        x = torch.from_numpy(waveforms).unsqueeze(0).to(self.device)
        scan_mask = torch.ones(1, n_scans, n_points, dtype=torch.bool, device=self.device)
        sigmas_t = torch.from_numpy(
            np.log10(np.asarray(sigmas, dtype=np.float32))
        ).unsqueeze(0).to(self.device)
        flux_scales_t = torch.from_numpy(
            np.asarray(flux_scales, dtype=np.float32)
        ).unsqueeze(0).to(self.device)

        return {
            "input": x,
            "scan_mask": scan_mask,
            "sigmas": sigmas_t,
            "flux_scales": flux_scales_t,
        }

    def _prepare_tpd_tensor(self, temperatures, rates, betas):
        """
        Build model input tensors from TPD data.

        Each curve is resampled onto ``TPD_SEQ_LEN`` uniformly spaced
        temperatures. When the loaded model uses summary features, they are
        extracted from the resampled, not yet peak-normalized, rates. Each
        rate curve is then divided by its own peak magnitude, with log10 of
        the peak passed as ``flux_scales``. Finally, when norm stats are
        loaded, channels are standardized as ``(x - stats[0]) / stats[1]``.

        Args:
            temperatures: list of 1-D arrays (K)
            rates: list of 1-D arrays (arb. units)
            betas: 1-D array of heating rates (K/s), one per curve

        Returns:
            dict of tensors ready for model.predict(): ``input`` [1, N, 2, T],
            ``scan_mask`` [1, N, T] (all True), ``sigmas`` (log10 heating
            rates), ``flux_scales`` (log10 per-curve peaks) and, when summary
            features are enabled, ``summary``.

        Raises:
            ValueError: if no curves are given, or the number of rate arrays
                or heating rates differs from the number of temperature arrays.
        """
        from scipy.interpolate import interp1d

        n_rates = len(temperatures)
        if n_rates == 0:
            raise ValueError("At least one TPD curve is required.")
        if len(rates) != n_rates or np.size(betas) != n_rates:
            raise ValueError(
                f"Expected {n_rates} rate arrays and heating rates, "
                f"got {len(rates)} and {np.size(betas)}."
            )
        n_points = self.TPD_SEQ_LEN

        temp_rows, rate_rows = [], []
        for temp_raw, rate_raw in zip(temperatures, rates):
            temp = np.asarray(temp_raw, dtype=np.float32)
            rate = np.asarray(rate_raw, dtype=np.float32)
            t_uniform = np.linspace(temp[0], temp[-1], n_points)
            temp_rows.append(t_uniform)
            rate_rows.append(
                interp1d(temp, rate, kind="linear", fill_value="extrapolate")(t_uniform)
            )

        temp_arr = np.stack(temp_rows).astype(np.float32)
        rate_arr = np.stack(rate_rows).astype(np.float32)

        # Summary features are computed before the per-curve peak
        # normalization below, so they see the resampled raw rates.
        summary_t = None
        if getattr(self, "tpd_use_summary", False):
            from preprocessing import extract_tpd_summary_stats
            hr_arr = np.asarray(betas, dtype=np.float32)
            lengths = np.full(n_rates, n_points, dtype=np.int32)
            summary = extract_tpd_summary_stats(
                temp_arr, rate_arr, lengths, hr_arr, n_rates)
            summary_t = torch.from_numpy(summary).unsqueeze(0).to(self.device)

        rate_scales = []
        for i in range(n_rates):
            # The epsilon guards an all-zero curve.
            peak = np.max(np.abs(rate_arr[i])) + 1e-30
            rate_scales.append(np.log10(peak))
            rate_arr[i] /= peak

        ns = self.tpd_norm_stats
        if ns:
            temp_arr = (temp_arr - ns["temperature"][0]) / ns["temperature"][1]
            rate_arr = (rate_arr - ns["rate"][0]) / ns["rate"][1]

        # [1, N, 2, T]
        waveforms = np.stack([temp_arr, rate_arr], axis=1)
        x = torch.from_numpy(waveforms).unsqueeze(0).to(self.device)
        scan_mask = torch.ones(1, n_rates, n_points, dtype=torch.bool, device=self.device)
        sigmas_t = torch.from_numpy(
            np.log10(np.asarray(betas, dtype=np.float32))
        ).unsqueeze(0).to(self.device)
        rate_scales_t = torch.from_numpy(
            np.asarray(rate_scales, dtype=np.float32)
        ).unsqueeze(0).to(self.device)

        result = {
            "input": x,
            "scan_mask": scan_mask,
            "sigmas": sigmas_t,
            "flux_scales": rate_scales_t,
        }
        if summary_t is not None:
            result["summary"] = summary_t
        return result

    @torch.no_grad()
    def predict_ec(self, potentials, fluxes, sigmas, times=None,
                   n_samples=500, temperature=1.0):
        """
        Run EC inference on dimensionless CV data.

        Args:
            potentials: list of 1-D arrays (dimensionless theta per scan rate)
            fluxes: list of 1-D arrays (dimensionless flux per scan rate)
            sigmas: list/array of dimensionless scan rates
            times: optional list of 1-D time arrays
            n_samples: posterior samples to draw
            temperature: sampling temperature (>1 broadens posteriors)

        Returns:
            dict with domain, mechanism_probs, mechanism_names,
            predicted_mechanism, predicted_mechanism_idx, parameter_stats
            (per mechanism) and posterior_samples (per mechanism)

        Raises:
            RuntimeError: if no EC model is loaded.
            ValueError: on empty or length-mismatched inputs.
        """
        if self.ec_model is None:
            raise RuntimeError("EC model not loaded")

        tensors = self._prepare_ec_tensor(potentials, fluxes, times, sigmas)
        pred = self.ec_model.predict(
            tensors["input"],
            scan_mask=tensors["scan_mask"],
            sigmas=tensors["sigmas"],
            flux_scales=tensors["flux_scales"],
            n_samples=n_samples,
            temperature=temperature,
        )
        return _format_prediction("ec", pred, MECHANISM_LIST, MECHANISM_PARAMS)

    @torch.no_grad()
    def predict_tpd(self, temperatures, rates, betas,
                    n_samples=500, temperature=1.0):
        """
        Run TPD inference.

        Args:
            temperatures: list of 1-D arrays (K per heating rate)
            rates: list of 1-D arrays (signal per heating rate)
            betas: list/array of heating rates (K/s)
            n_samples: posterior samples to draw
            temperature: sampling temperature

        Returns:
            dict with domain, mechanism_probs, mechanism_names,
            predicted_mechanism, predicted_mechanism_idx, parameter_stats and
            posterior_samples

        Raises:
            RuntimeError: if no TPD model is loaded.
            ValueError: on empty or length-mismatched inputs.
        """
        if self.tpd_model is None:
            raise RuntimeError("TPD model not loaded")

        tensors = self._prepare_tpd_tensor(temperatures, rates, betas)
        pred = self.tpd_model.predict(
            tensors["input"],
            scan_mask=tensors["scan_mask"],
            sigmas=tensors["sigmas"],
            flux_scales=tensors["flux_scales"],
            n_samples=n_samples,
            temperature=temperature,
            summary=tensors.get("summary"),
        )
        return _format_prediction(
            "tpd", pred, TPD_MECHANISM_LIST, TPD_MECHANISM_PARAMS
        )

    # =====================================================================
    # Signal Reconstruction
    # =====================================================================

    def reconstruct_ec(self, result, potentials, fluxes, sigmas,
                       base_params=None, mechanism=None):
        """
        Reconstruct CV signals from the inferred posterior median and score them.

        Args:
            result: output dict from predict_ec()
            potentials: list of 1-D arrays (original dimensionless theta)
            fluxes: list of 1-D arrays (original dimensionless flux)
            sigmas: list of dimensionless scan rates
            base_params: dict of fixed simulation params; when None, the
                initial/vertex potentials are taken as the max/min of the
                first scan and the remaining values use fixed defaults
            mechanism: which mechanism to reconstruct (default: predicted)

        Returns:
            dict with 'observed', 'reconstructed' and 'concentrations' curve
            lists, per-scan 'nrmse' and 'r2', and 'mean_nrmse'/'mean_r2' over
            the finite entries. Returns None if the mechanism has no
            parameter stats or the simulation raises (the error is logged).
        """
        from evaluate_reconstruction import (
            reconstruct_ec_signal, signal_nrmse, signal_r2,
        )

        mech = mechanism or result["predicted_mechanism"]
        stats = result["parameter_stats"].get(mech)
        if stats is None:
            return None
        theta_point = np.array(stats["median"])

        if base_params is None:
            pot0 = np.asarray(potentials[0])
            base_params = {
                "theta_i": float(pot0.max()),
                "theta_v": float(pot0.min()),
                "dA": 1.0,
                "C_A_bulk": 1.0,
                "C_B_bulk": 0.0,
                "kinetics": mech,
            }

        try:
            recon_results = reconstruct_ec_signal(
                theta_point, mech, base_params, sigmas, n_spatial=64
            )
        except Exception as exc:  # best-effort: report failure as None
            logger.warning("EC reconstruction failed for mechanism %r: %s", mech, exc)
            return None

        observed_curves, recon_curves, conc_curves = [], [], []
        nrmses, r2s = [], []
        # sigmas stays in the zip so iteration stops at the shortest input.
        for i, (pot, flx, _sigma) in enumerate(zip(potentials, fluxes, sigmas)):
            pot = np.asarray(pot)
            flx = np.asarray(flx)
            observed_curves.append({"x": pot, "y": flx})

            rec = recon_results[i] if i < len(recon_results) else {}
            if not rec.get("success", False):
                recon_curves.append({"x": pot, "y": np.zeros_like(flx)})
                nrmses.append(float("nan"))
                r2s.append(float("nan"))
                conc_curves.append(None)
                continue

            # Align simulated and observed samples by normalized sample index.
            # NOTE(review): this assumes both traces span the same sweep with
            # uniform sampling — confirm against the simulator output.
            rec_pot = np.asarray(rec["potential"])
            rec_flx = np.asarray(rec["flux"])
            t_obs = np.linspace(0, 1, len(pot))
            t_rec = np.linspace(0, 1, len(rec_pot))
            rec_flx_interp = np.interp(t_obs, t_rec, rec_flx)
            recon_curves.append({"x": pot, "y": rec_flx_interp})

            nrmses.append(signal_nrmse(flx, rec_flx_interp))
            r2s.append(signal_r2(flx, rec_flx_interp))

            if "c_ox_surface" in rec and "c_red_surface" in rec:
                conc_curves.append({
                    "x": pot,
                    "c_ox": np.interp(t_obs, t_rec, np.asarray(rec["c_ox_surface"])),
                    "c_red": np.interp(t_obs, t_rec, np.asarray(rec["c_red_surface"])),
                })
            else:
                conc_curves.append(None)

        return {
            "observed": observed_curves,
            "reconstructed": recon_curves,
            "concentrations": conc_curves,
            "nrmse": nrmses,
            "r2": r2s,
            "mean_nrmse": _finite_mean(nrmses),
            "mean_r2": _finite_mean(r2s),
        }

    def reconstruct_tpd(self, result, temperatures, rates, betas,
                        base_params=None, mechanism=None):
        """
        Reconstruct TPD signals from the inferred posterior median and score them.

        Args:
            result: output dict from predict_tpd()
            temperatures: list of 1-D arrays (K)
            rates: list of 1-D arrays (signal)
            betas: list of heating rates (K/s)
            base_params: dict of fixed simulation params; when None, the
                temperature range of the first curve and 500 points are used
            mechanism: which mechanism to reconstruct (default: predicted)

        Returns:
            dict with 'observed' and 'reconstructed' curve lists, per-rate
            'nrmse' and 'r2', and 'mean_nrmse'/'mean_r2' over the finite
            entries. Returns None if the mechanism has no parameter stats or
            the simulation raises (the error is logged).
        """
        from evaluate_reconstruction import (
            reconstruct_tpd_signal, signal_nrmse, signal_r2,
        )

        mech = mechanism or result["predicted_mechanism"]
        stats = result["parameter_stats"].get(mech)
        if stats is None:
            return None
        theta_point = np.array(stats["median"])

        if base_params is None:
            temp0 = np.asarray(temperatures[0])
            base_params = {
                "mechanism": mech,
                "T_start": float(temp0.min()),
                "T_end": float(temp0.max()),
                "n_points": 500,
            }

        try:
            recon_results = reconstruct_tpd_signal(
                theta_point, mech, base_params, betas
            )
        except Exception as exc:  # best-effort: report failure as None
            logger.warning("TPD reconstruction failed for mechanism %r: %s", mech, exc)
            return None

        observed_curves, recon_curves = [], []
        nrmses, r2s = [], []
        # betas stays in the zip so iteration stops at the shortest input.
        for i, (temp, rate, _beta) in enumerate(zip(temperatures, rates, betas)):
            temp = np.asarray(temp)
            rate = np.asarray(rate)
            observed_curves.append({"x": temp, "y": rate})

            rec = recon_results[i] if i < len(recon_results) else {}
            if not rec.get("success", False):
                recon_curves.append({"x": temp, "y": np.zeros_like(rate)})
                nrmses.append(float("nan"))
                r2s.append(float("nan"))
                continue

            # Evaluate the simulated curve at the observed temperatures
            # (np.interp expects rec["temperature"] to be increasing).
            rec_rate_interp = np.interp(
                temp, np.asarray(rec["temperature"]), np.asarray(rec["rate"])
            )
            recon_curves.append({"x": temp, "y": rec_rate_interp})
            nrmses.append(signal_nrmse(rate, rec_rate_interp))
            r2s.append(signal_r2(rate, rec_rate_interp))

        return {
            "observed": observed_curves,
            "reconstructed": recon_curves,
            "nrmse": nrmses,
            "r2": r2s,
            "mean_nrmse": _finite_mean(nrmses),
            "mean_r2": _finite_mean(r2s),
        }