| """ |
| Multi-Mechanism Normalizing Flow for Joint Mechanism Identification |
| and Bayesian Parameter Inference from Multi-Heating-Rate TPD Signals. |
| |
| Architecture mirrors the electrochemistry model (multi_mechanism_model.py) |
| but configured for TPD with 2 input channels (temperature, rate) and |
| 6 catalysis mechanisms. |
| |
| Reuses domain-agnostic components from flow_model.py and |
| multi_mechanism_model.py (SAB, PMA, MechanismClassifier, MechanismFlow). |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
|
|
| from flow_model import ( |
| SignalEncoder, |
| ActNorm, |
| ConditionalSplineCoupling, |
| ConditionalAffineCoupling, |
| ) |
| from multi_mechanism_model import MechanismClassifier, MechanismFlow, SAB, PMA, SummaryProjection, SUMMARY_DIM |
|
|
| from generate_tpd_data import TPD_MECHANISM_LIST, TPD_MECHANISM_PARAMS |
|
|
|
|
class MultiScanEncoderTPD(nn.Module):
    """
    Encode a set of multi-heating-rate TPD curves into one context vector.

    Set-Transformer pipeline:
      1. A shared per-curve CNN (SignalEncoder) embeds every curve.
      2. Each embedding is augmented with two scalars:
         [log10(heating_rate), log10(peak_rate)].
      3. SAB mixes information across heating rates (self-attention).
      4. PMA pools the set down to a single vector (attention pooling).
      5. A small MLP (rho) projects to the final context.

    Input:  x [B, N_beta, 2, T], scan_mask [B, N_beta, T],
            sigmas [B, N_beta], flux_scales [B, N_beta]
    Output: context [B, d_context]
    """
    def __init__(self, in_channels=2, d_model=128, d_context=128, n_heads=4):
        super().__init__()
        # Shared CNN applied independently to every curve in the set.
        self.per_cv_encoder = SignalEncoder(
            in_channels=in_channels, d_model=d_model, d_context=d_context,
        )
        # Folds the two scalar side-features into each curve embedding.
        self.cv_augment = nn.Sequential(
            nn.Linear(d_context + 2, d_context),
            nn.GELU(),
        )
        # Set attention (across heating rates) and attention pooling.
        self.sab = SAB(d_context, n_heads=n_heads)
        self.pma = PMA(d_context, n_heads=n_heads, n_seeds=1)
        # Final projection to context space.
        self.rho = nn.Sequential(
            nn.Linear(d_context, d_context),
            nn.GELU(),
            nn.Linear(d_context, d_context),
        )

    def forward(self, x, scan_mask=None, sigmas=None, flux_scales=None):
        """
        Args:
            x: [B, N_beta, 2, T] multi-heating-rate TPD curves
            scan_mask: [B, N_beta, T] valid-timestep mask (True = keep)
            sigmas: [B, N_beta] log10 heating rates
            flux_scales: [B, N_beta] log10(peak_rate) per curve

        Returns:
            context: [B, d_context]
        """
        batch, n_curves, n_chan, n_t = x.shape

        # Run the shared per-curve encoder over the flattened (B*N) set.
        flat = x.reshape(batch * n_curves, n_chan, n_t)
        flat_mask = None
        if scan_mask is not None:
            flat_mask = scan_mask.reshape(batch * n_curves, n_t)
        emb = self.per_cv_encoder(flat, mask=flat_mask)
        emb = emb.reshape(batch, n_curves, -1)

        # Missing side-features default to zeros (a neutral augmentation).
        if sigmas is None:
            sigmas = torch.zeros(batch, n_curves, device=x.device)
        if flux_scales is None:
            flux_scales = torch.zeros(batch, n_curves, device=x.device)
        side = torch.stack((sigmas, flux_scales), dim=-1)
        emb = self.cv_augment(torch.cat((emb, side), dim=-1))

        # A curve with no valid timestep at all is padding for set attention.
        pad = None if scan_mask is None else ~scan_mask.any(dim=-1)

        pooled = self.sab(emb, key_padding_mask=pad)
        pooled = self.pma(pooled, key_padding_mask=pad).squeeze(1)
        return self.rho(pooled)
|
|
|
|
class MultiMechanismFlowTPD(nn.Module):
    """
    Joint mechanism identification and parameter inference model for TPD.

    Combines:
    - Multi-heating-rate signal encoder (Set Transformer over per-curve embeddings)
    - Mechanism classifier over the mechanisms in TPD_MECHANISM_LIST
    - Per-mechanism normalizing flow heads (each with its own theta dimension)

    If use_summary_features=True, replaces the signal encoder with a simple
    MLP projection from hand-crafted summary statistics (SUMMARY_DIM-dim) to
    context space, keeping all other components identical.
    """
    def __init__(
        self,
        d_context=128,
        d_model=128,
        n_coupling_layers=6,
        hidden_dim=96,
        coupling_type='spline',
        n_bins=8,
        tail_bound=5.0,
        use_summary_features=False,
    ):
        """
        Args:
            d_context: width of the context vector fed to classifier/flows
            d_model: per-curve CNN width (ignored in summary mode)
            n_coupling_layers: coupling layers per flow head
            hidden_dim: hidden width of classifier and coupling networks
            coupling_type: 'spline' or 'affine' coupling transform
            n_bins: spline bins (spline coupling only)
            tail_bound: spline tail bound (spline coupling only)
            use_summary_features: use summary stats instead of the learned encoder
        """
        super().__init__()
        self.n_mechanisms = len(TPD_MECHANISM_LIST)
        self.mechanism_list = TPD_MECHANISM_LIST
        self.d_context = d_context
        self.use_summary_features = use_summary_features

        # Exactly one of (encoder, summary_proj) is active; the inactive one is
        # None so checkpoints make the chosen mode explicit.
        if use_summary_features:
            self.summary_proj = SummaryProjection(
                summary_dim=SUMMARY_DIM, d_context=d_context,
            )
            self.encoder = None
        else:
            self.encoder = MultiScanEncoderTPD(
                in_channels=2, d_model=d_model, d_context=d_context,
            )
            self.summary_proj = None

        self.classifier = MechanismClassifier(
            d_context=d_context,
            n_mechanisms=self.n_mechanisms,
            hidden_dim=hidden_dim,
        )

        # One conditional flow head per mechanism; theta dimensionality varies.
        self.flow_heads = nn.ModuleDict()
        for mech in TPD_MECHANISM_LIST:
            theta_dim = TPD_MECHANISM_PARAMS[mech]['dim']
            self.flow_heads[mech] = MechanismFlow(
                theta_dim=theta_dim,
                d_context=d_context,
                n_coupling_layers=n_coupling_layers,
                hidden_dim=hidden_dim,
                coupling_type=coupling_type,
                n_bins=n_bins,
                tail_bound=tail_bound,
            )

    def set_theta_stats(self, mechanism, mean, std):
        """Set normalization stats for a specific mechanism's flow head."""
        self.flow_heads[mechanism].set_theta_stats(mean, std)

    def encode_signal(self, x, scan_mask=None, sigmas=None, flux_scales=None,
                      summary=None):
        """Map raw curves (or summary stats in summary mode) to [B, d_context]."""
        if self.use_summary_features:
            assert summary is not None, "summary features required in summary mode"
            return self.summary_proj(summary)
        return self.encoder(x, scan_mask=scan_mask, sigmas=sigmas,
                            flux_scales=flux_scales)

    def _safe_log_prob(self, mech, theta_m, ctx_m):
        """
        Flow log-density with non-finite entries clamped to a finite floor.

        Occasional numerical blow-ups in the coupling transforms can produce
        inf/NaN log-probs; the replacement constant is detached so no gradient
        flows through clamped entries, keeping the training loss finite.
        """
        log_p = self.flow_heads[mech].log_prob(theta_m, ctx_m)
        bad = ~torch.isfinite(log_p)
        if bad.any():
            log_p = torch.where(bad, torch.full_like(log_p, -10.0).detach(), log_p)
        return log_p

    @staticmethod
    def _resolve_temperature(mechanism, temperature, temperature_map, device):
        """
        Return the sampling temperature for a mechanism: a per-parameter tensor
        from temperature_map when available, else the scalar fallback.

        The tensor is created on `device` so it matches the context tensor
        (a plain CPU tensor could otherwise mismatch a CUDA context inside the
        flow head's sampling arithmetic).
        """
        if temperature_map is not None and mechanism in temperature_map:
            return torch.tensor(temperature_map[mechanism], dtype=torch.float32,
                                device=device)
        return temperature

    def forward(self, x, mechanism_ids, mech_theta, mech_theta_mask=None,
                scan_mask=None, sigmas=None, flux_scales=None, summary=None):
        """
        Compute classification logits and per-sample NLL for the true mechanism.

        Args:
            x: [B, N_beta, 2, T] curves (ignored in summary mode)
            mechanism_ids: [B] integer index of the true mechanism
            mech_theta: [B, max_dim] true parameters, zero-padded per mechanism
            mech_theta_mask: accepted for interface compatibility; unused here
            scan_mask / sigmas / flux_scales: per-curve metadata (see encoder)
            summary: [B, SUMMARY_DIM] summary stats (summary mode only)

        Returns:
            dict with 'logits' [B, n_mechanisms] and 'nll' [B]
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        logits = self.classifier(context)

        # Sized/placed from the context so summary mode works even without x.
        nll = torch.zeros(context.shape[0], device=context.device)

        for m_idx, mech in enumerate(TPD_MECHANISM_LIST):
            sel = (mechanism_ids == m_idx)
            if not sel.any():
                continue
            # Each mechanism only uses the first theta_dim columns of the
            # zero-padded theta matrix.
            theta_dim = TPD_MECHANISM_PARAMS[mech]['dim']
            log_p = self._safe_log_prob(mech, mech_theta[sel, :theta_dim],
                                        context[sel])
            nll[sel] = -log_p

        return {'logits': logits, 'nll': nll}

    def forward_with_calibration(self, x, mechanism_ids, mech_theta,
                                 mech_theta_mask=None, scan_mask=None,
                                 sigmas=None, flux_scales=None,
                                 cal_n_samples=64, cal_levels=(0.5, 0.9),
                                 cal_beta=20.0, summary=None):
        """
        Forward pass with an additional differentiable calibration loss.

        For each mechanism with at least 4 batch items, draws cal_n_samples
        posterior samples per item and penalizes the squared gap between the
        soft empirical coverage of the central `cal_levels` intervals and
        their nominal levels (see MultiMechanismFlow).

        Returns:
            dict with 'logits' [B, n_mechanisms], 'nll' [B], 'cal_loss' scalar
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        logits = self.classifier(context)

        nll = torch.zeros(context.shape[0], device=context.device)
        cal_losses = []

        for m_idx, mech in enumerate(TPD_MECHANISM_LIST):
            sel = (mechanism_ids == m_idx)
            if not sel.any():
                continue

            theta_dim = TPD_MECHANISM_PARAMS[mech]['dim']
            ctx_m = context[sel]
            theta_m = mech_theta[sel, :theta_dim]

            nll[sel] = -self._safe_log_prob(mech, theta_m, ctx_m)

            # Quantile-based coverage needs a minimally-sized group.
            if ctx_m.shape[0] < 4:
                continue
            samples = self.flow_heads[mech].sample_with_grad(
                ctx_m, n_samples=cal_n_samples,
            )

            # Down-weight parameters with wide posteriors so tightly
            # constrained parameters dominate the calibration signal.
            with torch.no_grad():
                param_std = samples.std(dim=1).clamp(min=1e-4)
                inv_spread_w = 1.0 / param_std
                inv_spread_w = inv_spread_w / inv_spread_w.mean()

            for level in cal_levels:
                alpha = (1.0 - level) / 2.0
                lower = torch.quantile(samples, alpha, dim=1)
                upper = torch.quantile(samples, 1 - alpha, dim=1)

                # Smooth (differentiable) indicator of theta inside
                # [lower, upper]; cal_beta sets the sigmoid sharpness.
                inside = (
                    torch.sigmoid(cal_beta * (theta_m - lower))
                    * torch.sigmoid(cal_beta * (upper - theta_m))
                )

                per_sample_loss = (inside - level).pow(2)
                cal_losses.append((per_sample_loss * inv_spread_w).mean())

        if cal_losses:
            cal_loss = torch.stack(cal_losses).mean()
        else:
            cal_loss = torch.tensor(0.0, device=context.device)

        return {'logits': logits, 'nll': nll, 'cal_loss': cal_loss}

    @torch.no_grad()
    def predict(self, x, scan_mask=None, sigmas=None, flux_scales=None,
                n_samples=200, top_k=None, temperature=1.0,
                temperature_map=None, summary=None):
        """
        Full inference: classify mechanism, then sample parameters.

        Args:
            n_samples: posterior samples per batch item and mechanism
            top_k: if set, only sample flows for mechanisms appearing in some
                batch item's top-k; skipped mechanisms get None entries
            temperature: scalar fallback (>1 broadens posteriors)
            temperature_map: dict mapping mechanism name -> list of
                per-parameter temperatures. Overrides scalar temperature.
            summary: [B, SUMMARY_DIM] hand-crafted summary stats (summary mode only)

        Returns:
            dict with mechanism_probs, mechanism_xdB, mechanism_pred, samples, stats
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        logits = self.classifier(context)
        probs = F.softmax(logits, dim=-1)
        pred = probs.argmax(dim=-1)

        # Log-odds in decibels; clamping keeps the logit finite.
        probs_clamped = probs.clamp(min=1e-7, max=1 - 1e-7)
        xdB = 10.0 * torch.log10(probs_clamped / (1.0 - probs_clamped))

        # Hoisted out of the mechanism loop: top-k indices are loop-invariant.
        top_k_idx = probs.topk(top_k, dim=-1).indices if top_k is not None else None

        samples_dict = {}
        stats_dict = {}

        for m_idx, mech in enumerate(TPD_MECHANISM_LIST):
            # Skip mechanisms outside every batch item's top-k.
            if top_k_idx is not None and not (top_k_idx == m_idx).any():
                samples_dict[mech] = None
                stats_dict[mech] = None
                continue

            T = self._resolve_temperature(mech, temperature, temperature_map,
                                          context.device)
            s = self.flow_heads[mech].sample(context, n_samples=n_samples,
                                             temperature=T)
            samples_dict[mech] = s
            stats_dict[mech] = {
                'mean': s.mean(dim=1),
                'std': s.std(dim=1),
                'median': s.median(dim=1).values,
                'q05': s.quantile(0.05, dim=1),
                'q95': s.quantile(0.95, dim=1),
            }

        return {
            'mechanism_probs': probs,
            'mechanism_xdB': xdB,
            'mechanism_pred': pred,
            'samples': samples_dict,
            'stats': stats_dict,
        }

    def predict_single_mechanism(self, x, mechanism, scan_mask=None,
                                 sigmas=None, flux_scales=None, n_samples=1000,
                                 temperature=1.0, temperature_map=None,
                                 summary=None):
        """
        Sample parameters assuming a known mechanism.

        Returns:
            dict with 'samples' [B, n_samples, theta_dim] plus per-parameter
            summary stats (mean/std/median/q05/q95).
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        T = self._resolve_temperature(mechanism, temperature, temperature_map,
                                      context.device)
        samples = self.flow_heads[mechanism].sample(context, n_samples=n_samples,
                                                    temperature=T)
        return {
            'mean': samples.mean(dim=1),
            'std': samples.std(dim=1),
            'median': samples.median(dim=1).values,
            'q05': samples.quantile(0.05, dim=1),
            'q95': samples.quantile(0.95, dim=1),
            'samples': samples,
        }
|
|
|
|
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for p in model.parameters():
        if p.requires_grad:
            total += p.numel()
    return total
|
|
|
|
if __name__ == "__main__":
    # Smoke test: one batch item per mechanism, 3 heating rates, 500 timesteps.
    n_mechs = len(TPD_MECHANISM_LIST)
    B, N_beta, T = n_mechs, 3, 500

    # Synthetic inputs with fully-valid scan masks.
    x = torch.randn(B, N_beta, 2, T)
    scan_mask = torch.ones(B, N_beta, T, dtype=torch.bool)
    sigmas = torch.randn(B, N_beta)
    flux_scales = torch.randn(B, N_beta)
    mechanism_ids = torch.arange(n_mechs)

    # Zero-padded theta matrix plus a validity mask marking each mechanism's
    # true parameter dimensionality.
    max_dim = max(TPD_MECHANISM_PARAMS[m]['dim'] for m in TPD_MECHANISM_LIST)
    mech_theta = torch.randn(B, max_dim)
    mech_theta_mask = torch.zeros(B, max_dim, dtype=torch.bool)
    for row, mech_id in enumerate(mechanism_ids.tolist()):
        dim = TPD_MECHANISM_PARAMS[TPD_MECHANISM_LIST[mech_id]]['dim']
        mech_theta_mask[row, :dim] = True

    banner = "=" * 60
    print(banner)
    print("Testing MultiMechanismFlowTPD (multi-heating-rate, Set Transformer)")
    print(banner)

    model = MultiMechanismFlowTPD(
        d_context=128,
        d_model=128,
        n_coupling_layers=8,
        hidden_dim=128,
        coupling_type='affine',
    )

    # Parameter counts, broken down by component.
    total_params = count_parameters(model)
    print(f"Total parameters: {total_params:,}")
    print(f"  Encoder: {count_parameters(model.encoder):,}")
    print(f"  Classifier: {count_parameters(model.classifier):,}")
    for mech in TPD_MECHANISM_LIST:
        print(f"  Flow ({mech}, dim={TPD_MECHANISM_PARAMS[mech]['dim']}): "
              f"{count_parameters(model.flow_heads[mech]):,}")

    # Training-style forward pass: logits + per-sample NLL.
    out = model(x, mechanism_ids, mech_theta, mech_theta_mask,
                scan_mask=scan_mask, sigmas=sigmas, flux_scales=flux_scales)
    print(f"\nForward pass:")
    print(f"  Logits shape: {out['logits'].shape}")
    print(f"  NLL shape: {out['nll'].shape}")
    print(f"  NLL values: {out['nll']}")

    # Inference-style pass: classify, then sample every mechanism's posterior.
    pred = model.predict(x, scan_mask=scan_mask, sigmas=sigmas,
                         flux_scales=flux_scales, n_samples=100)
    print(f"\nPrediction:")
    print(f"  Mechanism probs shape: {pred['mechanism_probs'].shape}")
    print(f"  Mechanism xdB shape: {pred['mechanism_xdB'].shape}")
    print(f"  Predicted mechanisms: {pred['mechanism_pred']}")
    for mech in TPD_MECHANISM_LIST:
        s = pred['samples'][mech]
        if s is None:
            continue
        print(f"  {mech} samples shape: {s.shape}")
|