| """ |
| ECFlow inference engine. |
| |
| Loads trained EC and TPD models and runs end-to-end inference |
| from preprocessed arrays (dimensionless for CV, physical for TPD). |
| """ |
|
|
| import json |
| import sys |
| import os |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
|
|
|
|
| from multi_mechanism_model import MultiMechanismFlow |
| from tpd_model import MultiMechanismFlowTPD |
| from flow_model import MECHANISM_LIST, MECHANISM_PARAMS, ActNorm |
| from generate_tpd_data import TPD_MECHANISM_LIST, TPD_MECHANISM_PARAMS |
|
|
|
|
def _fix_actnorm_initialized(model):
    """Mark all ActNorm layers as initialized after loading a checkpoint.

    Old checkpoints lack the ``_initialized`` buffer, so ``load_state_dict``
    leaves it at ``False``. The first forward pass would then overwrite the
    trained ``log_scale``/``bias`` with data-dependent statistics.
    """
    stale_layers = (
        m for m in model.modules()
        if isinstance(m, ActNorm) and not m.initialized
    )
    for layer in stale_layers:
        layer.initialized = True
|
|
|
|
class ECFlowPredictor:
    """Unified predictor for both EC (cyclic voltammetry) and TPD domains.

    Loads trained flow checkpoints and exposes amortized posterior
    inference (``predict_ec`` / ``predict_tpd``) plus best-effort forward
    reconstruction diagnostics (``reconstruct_ec`` / ``reconstruct_tpd``).
    Either checkpoint may be omitted; the corresponding ``predict_*`` call
    then raises ``RuntimeError``.
    """

    def __init__(self, ec_checkpoint=None, tpd_checkpoint=None, device=None):
        """
        Args:
            ec_checkpoint: optional path to a trained EC checkpoint
            tpd_checkpoint: optional path to a trained TPD checkpoint
            device: torch device string; auto-selects CUDA when available
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # Models and their stats start as None so attribute access is always
        # safe, even when a checkpoint or a stats JSON file is missing.
        self.ec_model = None
        self.ec_norm_stats = None
        self.ec_theta_stats = None
        self.tpd_model = None
        self.tpd_norm_stats = None
        self.tpd_theta_stats = None

        if ec_checkpoint is not None:
            self._load_ec(ec_checkpoint)
        if tpd_checkpoint is not None:
            self._load_tpd(tpd_checkpoint)

    # ------------------------------------------------------------------
    # Checkpoint loading
    # ------------------------------------------------------------------

    @staticmethod
    def _stats_prefix(ckpt_path):
        """Derive a run-specific stats-file prefix from the checkpoint name.

        ``my_run_best.pt`` -> ``my_run_`` so ``my_run_norm_stats.json`` is
        preferred over the generic fallbacks; a bare ``best.pt`` yields "".
        """
        stem = ckpt_path.stem.replace("best", "").rstrip("_")
        return stem + "_" if stem else ""

    @staticmethod
    def _find_stats_json(ckpt_path, candidates):
        """Return the first JSON loaded among ``candidates``, or None.

        Searches the checkpoint's directory first, then its parent, trying
        the candidate names in priority order within each directory.
        """
        for search_dir in (ckpt_path.parent, ckpt_path.parent.parent):
            for name in candidates:
                path = search_dir / name
                if path.exists():
                    with open(path) as f:
                        return json.load(f)
        return None

    def _load_ec(self, ckpt_path):
        """Load the EC flow model and its normalization/theta statistics."""
        ckpt_path = Path(ckpt_path)
        # weights_only=False: the checkpoint stores a plain args dict next to
        # the state dict. Only load checkpoints from trusted sources.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        args = checkpoint["args"]

        self.ec_model = MultiMechanismFlow(
            d_context=args.get("d_context", 128),
            d_model=args.get("d_model", 128),
            n_coupling_layers=args.get("n_coupling_layers", 6),
            hidden_dim=args.get("hidden_dim", 96),
            coupling_type=args.get("coupling_type", "spline"),
            n_bins=args.get("n_bins", 8),
            tail_bound=args.get("tail_bound", 5.0),
        )
        # strict=False tolerates old checkpoints that miss newer buffers.
        self.ec_model.load_state_dict(checkpoint["model_state_dict"], strict=False)
        _fix_actnorm_initialized(self.ec_model)
        self.ec_model.to(self.device).eval()

        prefix = self._stats_prefix(ckpt_path)
        self.ec_norm_stats = self._find_stats_json(
            ckpt_path,
            [f"{prefix}norm_stats.json", "ec_norm_stats.json", "norm_stats.json"],
        )
        self.ec_theta_stats = self._find_stats_json(
            ckpt_path,
            [f"{prefix}theta_stats.json", "ec_theta_stats.json", "theta_stats.json"],
        )

    def _load_tpd(self, ckpt_path):
        """Load the TPD flow model and its normalization/theta statistics."""
        ckpt_path = Path(ckpt_path)
        # weights_only=False: the checkpoint stores a plain args dict next to
        # the state dict. Only load checkpoints from trusted sources.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        args = checkpoint["args"]

        self.tpd_use_summary = args.get("use_summary_features", False)

        self.tpd_model = MultiMechanismFlowTPD(
            d_context=args.get("d_context", 128),
            d_model=args.get("d_model", 128),
            n_coupling_layers=args.get("n_coupling_layers", 6),
            hidden_dim=args.get("hidden_dim", 96),
            coupling_type=args.get("coupling_type", "spline"),
            n_bins=args.get("n_bins", 8),
            tail_bound=args.get("tail_bound", 5.0),
            use_summary_features=self.tpd_use_summary,
        )
        # strict=False tolerates old checkpoints that miss newer buffers.
        self.tpd_model.load_state_dict(checkpoint["model_state_dict"], strict=False)
        _fix_actnorm_initialized(self.tpd_model)
        self.tpd_model.to(self.device).eval()

        prefix = self._stats_prefix(ckpt_path)
        self.tpd_norm_stats = self._find_stats_json(
            ckpt_path,
            [f"{prefix}norm_stats.json", "tpd_norm_stats.json", "norm_stats.json"],
        )
        self.tpd_theta_stats = self._find_stats_json(
            ckpt_path,
            [f"{prefix}theta_stats.json", "tpd_theta_stats.json", "theta_stats.json"],
        )

    # ------------------------------------------------------------------
    # Input preparation
    # ------------------------------------------------------------------

    def _prepare_ec_tensor(self, potentials, fluxes, times, sigmas):
        """
        Build model input tensor from preprocessed dimensionless CV data.

        Args:
            potentials: list of 1-D arrays (dimensionless theta)
            fluxes: list of 1-D arrays (dimensionless flux)
            times: list of 1-D arrays (dimensionless time) or None
            sigmas: 1-D array of dimensionless scan rates

        Returns:
            dict of tensors ready for model.predict()
        """
        from scipy.interpolate import interp1d

        n_scans = len(potentials)
        T_target = 672  # fixed sequence length expected by the EC encoder

        pot_resampled = []
        flux_resampled = []
        time_resampled = []
        flux_scales = []

        for i in range(n_scans):
            pot = np.asarray(potentials[i], dtype=np.float32)
            flx = np.asarray(fluxes[i], dtype=np.float32)

            if times is not None and times[i] is not None:
                tim = np.asarray(times[i], dtype=np.float32)
            else:
                # Reconstruct a time axis from the sweep: a full forward +
                # reverse CV cycle spans 2 * (theta range) / sigma.
                theta_range = pot.max() - pot.min()
                sigma = sigmas[i]
                total_time = 2.0 * theta_range / sigma
                tim = np.linspace(0, total_time, len(pot), dtype=np.float32)

            # Per-scan peak normalization; the log-magnitude is kept as a
            # conditioning feature so amplitude information is not lost.
            peak = np.max(np.abs(flx)) + 1e-30  # guard against all-zero flux
            flux_scales.append(np.log10(peak))
            flx = flx / peak

            # Resample everything onto a uniform time grid of fixed length.
            t_uniform = np.linspace(tim[0], tim[-1], T_target)
            pot_resampled.append(
                interp1d(tim, pot, kind="linear", fill_value="extrapolate")(t_uniform)
            )
            flux_resampled.append(
                interp1d(tim, flx, kind="linear", fill_value="extrapolate")(t_uniform)
            )
            time_resampled.append(t_uniform)

        pot_arr = np.stack(pot_resampled).astype(np.float32)
        flx_arr = np.stack(flux_resampled).astype(np.float32)
        tim_arr = np.stack(time_resampled).astype(np.float32)

        # Apply the training-time (shift, scale) normalization when available.
        ns = self.ec_norm_stats
        if ns:
            pot_arr = (pot_arr - ns["potential"][0]) / ns["potential"][1]
            flx_arr = (flx_arr - ns["flux"][0]) / ns["flux"][1]
            tim_arr = (tim_arr - ns["time"][0]) / ns["time"][1]

        # Shape: (1 batch, n_scans, 3 channels [pot, flux, time], T_target).
        waveforms = np.stack([pot_arr, flx_arr, tim_arr], axis=1)
        x = torch.from_numpy(waveforms).unsqueeze(0).to(self.device)
        scan_mask = torch.ones(1, n_scans, T_target, dtype=torch.bool, device=self.device)
        sigmas_t = torch.from_numpy(
            np.log10(np.asarray(sigmas, dtype=np.float32))
        ).unsqueeze(0).to(self.device)
        flux_scales_t = torch.from_numpy(
            np.asarray(flux_scales, dtype=np.float32)
        ).unsqueeze(0).to(self.device)

        return {
            "input": x,
            "scan_mask": scan_mask,
            "sigmas": sigmas_t,
            "flux_scales": flux_scales_t,
        }

    def _prepare_tpd_tensor(self, temperatures, rates, betas):
        """
        Build model input tensor from TPD data.

        Args:
            temperatures: list of 1-D arrays (K)
            rates: list of 1-D arrays (arb. units)
            betas: 1-D array of heating rates (K/s)

        Returns:
            dict of tensors ready for model.predict()
        """
        from scipy.interpolate import interp1d

        n_rates = len(temperatures)
        T_target = 500  # fixed sequence length expected by the TPD encoder

        temp_resampled = []
        rate_resampled = []

        for i in range(n_rates):
            temp = np.asarray(temperatures[i], dtype=np.float32)
            rate = np.asarray(rates[i], dtype=np.float32)

            # Resample onto a uniform temperature grid of fixed length.
            t_uniform = np.linspace(temp[0], temp[-1], T_target)
            temp_resampled.append(t_uniform)
            rate_resampled.append(
                interp1d(temp, rate, kind="linear", fill_value="extrapolate")(t_uniform)
            )

        temp_arr = np.stack(temp_resampled).astype(np.float32)
        rate_arr = np.stack(rate_resampled).astype(np.float32)

        # NOTE(review): summary features are deliberately extracted BEFORE
        # peak normalization (order preserved from the original pipeline);
        # presumably matches training-time extraction — confirm if changed.
        summary_t = None
        if getattr(self, "tpd_use_summary", False):
            from preprocessing import extract_tpd_summary_stats
            hr_arr = np.asarray(betas, dtype=np.float32)
            lengths = np.full(n_rates, T_target, dtype=np.int32)
            summary = extract_tpd_summary_stats(
                temp_arr, rate_arr, lengths, hr_arr, n_rates)
            summary_t = torch.from_numpy(summary).unsqueeze(0).to(self.device)

        # Per-curve peak normalization; log-magnitude kept as conditioning.
        rate_scales = []
        for i in range(n_rates):
            peak = np.max(np.abs(rate_arr[i])) + 1e-30  # guard: all-zero curve
            rate_scales.append(np.log10(peak))
            rate_arr[i] /= peak

        ns = self.tpd_norm_stats
        if ns:
            temp_arr = (temp_arr - ns["temperature"][0]) / ns["temperature"][1]
            rate_arr = (rate_arr - ns["rate"][0]) / ns["rate"][1]

        # Shape: (1 batch, n_rates, 2 channels [temp, rate], T_target).
        waveforms = np.stack([temp_arr, rate_arr], axis=1)
        x = torch.from_numpy(waveforms).unsqueeze(0).to(self.device)
        scan_mask = torch.ones(1, n_rates, T_target, dtype=torch.bool, device=self.device)
        sigmas_t = torch.from_numpy(
            np.log10(np.asarray(betas, dtype=np.float32))
        ).unsqueeze(0).to(self.device)
        rate_scales_t = torch.from_numpy(
            np.asarray(rate_scales, dtype=np.float32)
        ).unsqueeze(0).to(self.device)

        result = {
            "input": x,
            "scan_mask": scan_mask,
            "sigmas": sigmas_t,
            "flux_scales": rate_scales_t,
        }
        if summary_t is not None:
            result["summary"] = summary_t
        return result

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    @staticmethod
    def _collect_posteriors(pred, mech_list, mech_params):
        """Summarize per-mechanism posterior samples from a predict() output.

        Returns:
            (param_stats, samples_dict) — per-mechanism summary statistics
            and the raw sample arrays, keyed by mechanism name.
        """
        param_stats = {}
        samples_dict = {}
        for mech in mech_list:
            if pred["samples"][mech] is not None:
                s = pred["samples"][mech][0].cpu().numpy()
                samples_dict[mech] = s
                param_stats[mech] = {
                    "names": mech_params[mech]["names"],
                    "mean": s.mean(axis=0).tolist(),
                    "std": s.std(axis=0).tolist(),
                    "median": np.median(s, axis=0).tolist(),
                    "q05": np.quantile(s, 0.05, axis=0).tolist(),
                    "q95": np.quantile(s, 0.95, axis=0).tolist(),
                }
        return param_stats, samples_dict

    @torch.no_grad()
    def predict_ec(self, potentials, fluxes, sigmas, times=None, n_samples=500, temperature=1.0):
        """
        Run EC inference on dimensionless CV data.

        Args:
            potentials: list of 1-D arrays (dimensionless theta per scan rate)
            fluxes: list of 1-D arrays (dimensionless flux per scan rate)
            sigmas: list/array of dimensionless scan rates
            times: optional list of 1-D time arrays
            n_samples: posterior samples to draw
            temperature: sampling temperature (>1 broadens posteriors)

        Returns:
            dict with mechanism_probs, mechanism_names, predicted_mechanism,
            parameter_stats (per mechanism), posterior_samples (per mechanism)

        Raises:
            RuntimeError: if no EC checkpoint was loaded.
        """
        if self.ec_model is None:
            raise RuntimeError("EC model not loaded")

        tensors = self._prepare_ec_tensor(potentials, fluxes, times, sigmas)
        pred = self.ec_model.predict(
            tensors["input"],
            scan_mask=tensors["scan_mask"],
            sigmas=tensors["sigmas"],
            flux_scales=tensors["flux_scales"],
            n_samples=n_samples,
            temperature=temperature,
        )

        probs = pred["mechanism_probs"][0].cpu().numpy()
        pred_idx = int(pred["mechanism_pred"][0].cpu().item())
        param_stats, samples_dict = self._collect_posteriors(
            pred, MECHANISM_LIST, MECHANISM_PARAMS)

        return {
            "domain": "ec",
            "mechanism_probs": {m: float(probs[i]) for i, m in enumerate(MECHANISM_LIST)},
            "mechanism_names": MECHANISM_LIST,
            "predicted_mechanism": MECHANISM_LIST[pred_idx],
            "predicted_mechanism_idx": pred_idx,
            "parameter_stats": param_stats,
            "posterior_samples": samples_dict,
        }

    @torch.no_grad()
    def predict_tpd(self, temperatures, rates, betas, n_samples=500, temperature=1.0):
        """
        Run TPD inference.

        Args:
            temperatures: list of 1-D arrays (K per heating rate)
            rates: list of 1-D arrays (signal per heating rate)
            betas: list/array of heating rates (K/s)
            n_samples: posterior samples to draw
            temperature: sampling temperature

        Returns:
            dict with mechanism_probs, parameter_stats, posterior_samples

        Raises:
            RuntimeError: if no TPD checkpoint was loaded.
        """
        if self.tpd_model is None:
            raise RuntimeError("TPD model not loaded")

        tensors = self._prepare_tpd_tensor(temperatures, rates, betas)
        pred = self.tpd_model.predict(
            tensors["input"],
            scan_mask=tensors["scan_mask"],
            sigmas=tensors["sigmas"],
            flux_scales=tensors["flux_scales"],
            n_samples=n_samples,
            temperature=temperature,
            summary=tensors.get("summary"),
        )

        probs = pred["mechanism_probs"][0].cpu().numpy()
        pred_idx = int(pred["mechanism_pred"][0].cpu().item())
        param_stats, samples_dict = self._collect_posteriors(
            pred, TPD_MECHANISM_LIST, TPD_MECHANISM_PARAMS)

        return {
            "domain": "tpd",
            "mechanism_probs": {m: float(probs[i]) for i, m in enumerate(TPD_MECHANISM_LIST)},
            "mechanism_names": TPD_MECHANISM_LIST,
            "predicted_mechanism": TPD_MECHANISM_LIST[pred_idx],
            "predicted_mechanism_idx": pred_idx,
            "parameter_stats": param_stats,
            "posterior_samples": samples_dict,
        }

    # ------------------------------------------------------------------
    # Reconstruction diagnostics
    # ------------------------------------------------------------------

    def reconstruct_ec(self, result, potentials, fluxes, sigmas,
                       base_params=None, mechanism=None):
        """
        Reconstruct CV signals from inferred posterior median and compute metrics.

        Args:
            result: output dict from predict_ec()
            potentials: list of 1-D arrays (original dimensionless theta)
            fluxes: list of 1-D arrays (original dimensionless flux)
            sigmas: list of dimensionless scan rates
            base_params: dict of fixed simulation params; defaults used if None
            mechanism: which mechanism to reconstruct (default: predicted)

        Returns:
            dict with 'observed', 'reconstructed' curve lists,
            'nrmse', 'r2' per scan rate, and 'mean_nrmse', 'mean_r2';
            None if no stats exist for the mechanism or the solver fails.
        """
        from evaluate_reconstruction import (
            reconstruct_ec_signal, signal_nrmse, signal_r2,
        )

        mech = mechanism or result["predicted_mechanism"]
        stats = result["parameter_stats"].get(mech)
        if stats is None:
            return None

        # Point estimate: posterior median is robust to skewed marginals.
        theta_point = np.array(stats["median"])

        if base_params is None:
            pot0 = np.asarray(potentials[0])
            base_params = {
                "theta_i": float(pot0.max()),
                "theta_v": float(pot0.min()),
                "dA": 1.0,
                "C_A_bulk": 1.0,
                "C_B_bulk": 0.0,
                "kinetics": mech,
            }

        try:
            recon_results = reconstruct_ec_signal(
                theta_point, mech, base_params, sigmas, n_spatial=64
            )
        except Exception:
            # Reconstruction is a best-effort diagnostic; a solver failure
            # must not crash the inference pipeline.
            return None

        observed_curves = []
        recon_curves = []
        conc_curves = []
        nrmses = []
        r2s = []

        for i, (pot, flx, sigma) in enumerate(zip(potentials, fluxes, sigmas)):
            pot = np.asarray(pot)
            flx = np.asarray(flx)
            observed_curves.append({"x": pot, "y": flx})

            if i < len(recon_results) and recon_results[i].get("success", False):
                rec = recon_results[i]
                rec_pot = np.asarray(rec["potential"])
                rec_flx = np.asarray(rec["flux"])

                # Align by normalized pseudo-time rather than potential: the
                # CV sweep is non-monotonic in potential, so interpolating
                # over potential directly would be ill-defined.
                n_obs = len(pot)
                n_rec = len(rec_pot)
                t_obs = np.linspace(0, 1, n_obs)
                t_rec = np.linspace(0, 1, n_rec)
                rec_flx_interp = np.interp(t_obs, t_rec, rec_flx)
                recon_curves.append({"x": pot, "y": rec_flx_interp})
                nrmses.append(signal_nrmse(flx, rec_flx_interp))
                r2s.append(signal_r2(flx, rec_flx_interp))

                if "c_ox_surface" in rec and "c_red_surface" in rec:
                    c_ox_interp = np.interp(t_obs, t_rec, np.asarray(rec["c_ox_surface"]))
                    c_red_interp = np.interp(t_obs, t_rec, np.asarray(rec["c_red_surface"]))
                    conc_curves.append({
                        "x": pot,
                        "c_ox": c_ox_interp,
                        "c_red": c_red_interp,
                    })
                else:
                    conc_curves.append(None)
            else:
                # Failed scan: placeholder curve plus NaN metrics keep the
                # output lists index-aligned with the inputs.
                recon_curves.append({"x": pot, "y": np.zeros_like(flx)})
                nrmses.append(float("nan"))
                r2s.append(float("nan"))
                conc_curves.append(None)

        valid_nrmse = [v for v in nrmses if np.isfinite(v)]
        valid_r2 = [v for v in r2s if np.isfinite(v)]

        return {
            "observed": observed_curves,
            "reconstructed": recon_curves,
            "concentrations": conc_curves,
            "nrmse": nrmses,
            "r2": r2s,
            "mean_nrmse": float(np.mean(valid_nrmse)) if valid_nrmse else float("nan"),
            "mean_r2": float(np.mean(valid_r2)) if valid_r2 else float("nan"),
        }

    def reconstruct_tpd(self, result, temperatures, rates, betas,
                        base_params=None, mechanism=None):
        """
        Reconstruct TPD signals from inferred posterior median and compute metrics.

        Args:
            result: output dict from predict_tpd()
            temperatures: list of 1-D arrays (K)
            rates: list of 1-D arrays (signal)
            betas: list of heating rates (K/s)
            base_params: dict of fixed simulation params; defaults used if None
            mechanism: which mechanism to reconstruct (default: predicted)

        Returns:
            dict with 'observed', 'reconstructed' curve lists,
            'nrmse', 'r2' per heating rate, and 'mean_nrmse', 'mean_r2';
            None if no stats exist for the mechanism or the solver fails.
        """
        from evaluate_reconstruction import (
            reconstruct_tpd_signal, signal_nrmse, signal_r2,
        )

        mech = mechanism or result["predicted_mechanism"]
        stats = result["parameter_stats"].get(mech)
        if stats is None:
            return None

        # Point estimate: posterior median is robust to skewed marginals.
        theta_point = np.array(stats["median"])

        if base_params is None:
            temp0 = np.asarray(temperatures[0])
            base_params = {
                "mechanism": mech,
                "T_start": float(temp0.min()),
                "T_end": float(temp0.max()),
                "n_points": 500,
            }

        try:
            recon_results = reconstruct_tpd_signal(
                theta_point, mech, base_params, betas
            )
        except Exception:
            # Best-effort diagnostic; never crash the caller on solver error.
            return None

        observed_curves = []
        recon_curves = []
        nrmses = []
        r2s = []

        for i, (temp, rate, beta) in enumerate(zip(temperatures, rates, betas)):
            temp = np.asarray(temp)
            rate = np.asarray(rate)
            observed_curves.append({"x": temp, "y": rate})

            if i < len(recon_results) and recon_results[i].get("success", False):
                rec = recon_results[i]
                rec_temp = np.asarray(rec["temperature"])
                rec_rate = np.asarray(rec["rate"])

                # np.interp assumes rec_temp is increasing — a monotonic
                # heating ramp should satisfy this; confirm for down-ramps.
                rec_rate_interp = np.interp(temp, rec_temp, rec_rate)
                recon_curves.append({"x": temp, "y": rec_rate_interp})
                nrmses.append(signal_nrmse(rate, rec_rate_interp))
                r2s.append(signal_r2(rate, rec_rate_interp))
            else:
                # Failed curve: placeholder plus NaN metrics keep the output
                # lists index-aligned with the inputs.
                recon_curves.append({"x": temp, "y": np.zeros_like(rate)})
                nrmses.append(float("nan"))
                r2s.append(float("nan"))

        valid_nrmse = [v for v in nrmses if np.isfinite(v)]
        valid_r2 = [v for v in r2s if np.isfinite(v)]

        return {
            "observed": observed_curves,
            "reconstructed": recon_curves,
            "nrmse": nrmses,
            "r2": r2s,
            "mean_nrmse": float(np.mean(valid_nrmse)) if valid_nrmse else float("nan"),
            "mean_r2": float(np.mean(valid_r2)) if valid_r2 else float("nan"),
        }
|
|