"""
Multi-Mechanism Normalizing Flow for Joint Mechanism Identification and
Bayesian Parameter Inference from Multi-Heating-Rate TPD Signals.

Architecture mirrors the electrochemistry model (multi_mechanism_model.py)
but configured for TPD with 2 input channels (temperature, rate) and
6 catalysis mechanisms.

Reuses domain-agnostic components from flow_model.py and
multi_mechanism_model.py (SAB, PMA, MechanismClassifier, MechanismFlow).
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from flow_model import (
    SignalEncoder,
    ActNorm,
    ConditionalSplineCoupling,
    ConditionalAffineCoupling,
)
from multi_mechanism_model import (
    MechanismClassifier,
    MechanismFlow,
    SAB,
    PMA,
    SummaryProjection,
    SUMMARY_DIM,
)
from generate_tpd_data import TPD_MECHANISM_LIST, TPD_MECHANISM_PARAMS


class MultiScanEncoderTPD(nn.Module):
    """
    Encode a set of multi-heating-rate TPD curves into a single context vector.

    Architecture (Set Transformer):
      1. Shared per-curve CNN encoder -> per-curve embedding
      2. Augment with [log10(heating_rate), log10(peak_rate)]
      3. SAB: self-attention across heating rates
      4. PMA: attention-based pooling to single vector
      5. rho MLP: project to final context

    Input:
        x           [B, N_beta, 2, T]
        scan_mask   [B, N_beta, T]   (True = valid timestep)
        sigmas      [B, N_beta]      log10 heating rates
        flux_scales [B, N_beta]      log10(peak_rate) per curve
    Output:
        context     [B, d_context]
    """

    def __init__(self, in_channels=2, d_model=128, d_context=128, n_heads=4):
        super().__init__()
        # One encoder shared by every curve in the set.
        self.per_cv_encoder = SignalEncoder(
            in_channels=in_channels,
            d_model=d_model,
            d_context=d_context,
        )
        # Mixes the 2 per-curve scalar features back into the embedding.
        self.cv_augment = nn.Sequential(
            nn.Linear(d_context + 2, d_context),
            nn.GELU(),
        )
        self.sab = SAB(d_context, n_heads=n_heads)
        self.pma = PMA(d_context, n_heads=n_heads, n_seeds=1)
        self.rho = nn.Sequential(
            nn.Linear(d_context, d_context),
            nn.GELU(),
            nn.Linear(d_context, d_context),
        )

    def forward(self, x, scan_mask=None, sigmas=None, flux_scales=None):
        """
        Encode a batch of curve sets.

        Args:
            x: [B, N_beta, 2, T] multi-heating-rate TPD curves
            scan_mask: [B, N_beta, T] valid-timestep mask (True = valid), or
                None when every timestep of every curve is valid
            sigmas: [B, N_beta] log10 heating rates; zeros when None
            flux_scales: [B, N_beta] log10(peak_rate) per curve; zeros when None

        NOTE(review): the names ``sigmas`` / ``flux_scales`` look like they
        are kept from the electrochemistry encoder's interface; confirm.

        Returns:
            context: [B, d_context]
        """
        B, N, C, T = x.shape

        # Fold the heating-rate axis into the batch so every curve goes
        # through the same shared encoder, then unfold again.
        x_flat = x.reshape(B * N, C, T)
        mask_flat = scan_mask.reshape(B * N, T) if scan_mask is not None else None
        h = self.per_cv_encoder(x_flat, mask=mask_flat).reshape(B, N, -1)

        if sigmas is None:
            sigmas = torch.zeros(B, N, device=x.device)
        if flux_scales is None:
            flux_scales = torch.zeros(B, N, device=x.device)
        aug_features = torch.stack([sigmas, flux_scales], dim=-1)  # [B, N, 2]
        h = self.cv_augment(torch.cat([h, aug_features], dim=-1))

        # A curve with no valid timestep is treated as padding; it is passed
        # to SAB/PMA as key_padding_mask (True = padded).
        # NOTE(review): if *all* curves of a sample are padded, attention may
        # produce NaN depending on SAB/PMA internals; confirm callers avoid it.
        cv_invalid = ~scan_mask.any(dim=-1) if scan_mask is not None else None

        h = self.sab(h, key_padding_mask=cv_invalid)
        h = self.pma(h, key_padding_mask=cv_invalid)  # [B, 1, d_context]
        return self.rho(h.squeeze(1))


class MultiMechanismFlowTPD(nn.Module):
    """
    Joint mechanism identification and parameter inference model for TPD.

    Combines:
      - Multi-heating-rate signal encoder (Set Transformer over per-curve
        embeddings)
      - Mechanism classifier (6 TPD mechanisms)
      - Per-mechanism normalizing flow heads

    If use_summary_features=True, replaces the signal encoder with a simple
    MLP projection from hand-crafted summary statistics (21-dim) to context
    space, keeping all other components identical. In that mode only
    ``summary`` is read, so ``x`` may be passed as None.
    """

    # Value substituted for non-finite flow log-probabilities so a single
    # numerically unstable sample cannot turn the whole batch NLL into NaN/inf.
    _NONFINITE_LOG_PROB = -10.0

    def __init__(
        self,
        d_context=128,
        d_model=128,
        n_coupling_layers=6,
        hidden_dim=96,
        coupling_type='spline',
        n_bins=8,
        tail_bound=5.0,
        use_summary_features=False,
    ):
        super().__init__()
        self.n_mechanisms = len(TPD_MECHANISM_LIST)
        self.mechanism_list = TPD_MECHANISM_LIST
        self.d_context = d_context
        self.use_summary_features = use_summary_features

        # Exactly one of encoder / summary_proj is instantiated.
        if use_summary_features:
            self.summary_proj = SummaryProjection(
                summary_dim=SUMMARY_DIM,
                d_context=d_context,
            )
            self.encoder = None
        else:
            self.encoder = MultiScanEncoderTPD(
                in_channels=2,
                d_model=d_model,
                d_context=d_context,
            )
            self.summary_proj = None

        self.classifier = MechanismClassifier(
            d_context=d_context,
            n_mechanisms=self.n_mechanisms,
            hidden_dim=hidden_dim,
        )

        # One flow head per mechanism, sized to that mechanism's parameter count.
        self.flow_heads = nn.ModuleDict()
        for mech in TPD_MECHANISM_LIST:
            self.flow_heads[mech] = MechanismFlow(
                theta_dim=TPD_MECHANISM_PARAMS[mech]['dim'],
                d_context=d_context,
                n_coupling_layers=n_coupling_layers,
                hidden_dim=hidden_dim,
                coupling_type=coupling_type,
                n_bins=n_bins,
                tail_bound=tail_bound,
            )

    def set_theta_stats(self, mechanism, mean, std):
        """Set normalization stats for a specific mechanism's flow head."""
        self.flow_heads[mechanism].set_theta_stats(mean, std)

    def encode_signal(self, x, scan_mask=None, sigmas=None, flux_scales=None,
                      summary=None):
        """
        Map the input to a context vector [B, d_context].

        In summary mode ``summary`` is projected and ``x`` is ignored;
        otherwise the multi-scan encoder is applied to ``x``.

        Raises:
            ValueError: in summary mode when ``summary`` is None.
        """
        if self.use_summary_features:
            if summary is None:
                raise ValueError("summary features required in summary mode")
            return self.summary_proj(summary)
        return self.encoder(x, scan_mask=scan_mask, sigmas=sigmas,
                            flux_scales=flux_scales)

    def _safe_log_prob(self, mech, theta_m, ctx_m):
        """Flow log-density for ``mech`` with non-finite entries replaced."""
        log_p = self.flow_heads[mech].log_prob(theta_m, ctx_m)
        bad = ~torch.isfinite(log_p)
        if bad.any():
            log_p = torch.where(
                bad,
                torch.full_like(log_p, self._NONFINITE_LOG_PROB).detach(),
                log_p,
            )
        return log_p

    @staticmethod
    def _calibration_terms(samples, theta_m, cal_levels, cal_beta):
        """Soft interval-coverage losses for one mechanism group, one per level."""
        # Inverse-spread weights: collapsed parameters (small posterior std)
        # get more calibration pressure than well-spread parameters.
        with torch.no_grad():
            param_std = samples.std(dim=1).clamp(min=1e-4)  # [B_m, D]
            inv_spread_w = 1.0 / param_std                   # [B_m, D]
            inv_spread_w = inv_spread_w / inv_spread_w.mean()

        terms = []
        for level in cal_levels:
            alpha = (1.0 - level) / 2.0
            lower = torch.quantile(samples, alpha, dim=1)      # [B_m, D]
            upper = torch.quantile(samples, 1 - alpha, dim=1)  # [B_m, D]
            # Sigmoid-smoothed indicator of lower <= theta <= upper;
            # cal_beta sets how sharp the edges are.
            inside = (
                torch.sigmoid(cal_beta * (theta_m - lower))
                * torch.sigmoid(cal_beta * (upper - theta_m))
            )  # [B_m, D]
            per_sample_loss = (inside - level).pow(2)  # [B_m, D]
            terms.append((per_sample_loss * inv_spread_w).mean())
        return terms

    @staticmethod
    def _resolve_temperature(mech, temperature, temperature_map, device):
        """Per-parameter temperature from ``temperature_map`` if present, else the scalar."""
        if temperature_map is not None and mech in temperature_map:
            return torch.tensor(temperature_map[mech], dtype=torch.float32,
                                device=device)
        return temperature

    @staticmethod
    def _summarize_samples(samples):
        """Posterior summary statistics over the sample axis (dim=1)."""
        return {
            'mean': samples.mean(dim=1),
            'std': samples.std(dim=1),
            'median': samples.median(dim=1).values,
            'q05': samples.quantile(0.05, dim=1),
            'q95': samples.quantile(0.95, dim=1),
        }

    def forward(self, x, mechanism_ids, mech_theta, mech_theta_mask=None,
                scan_mask=None, sigmas=None, flux_scales=None, summary=None):
        """
        Compute classification logits and per-sample NLL for the true mechanism.

        Args:
            mechanism_ids: [B] index into TPD_MECHANISM_LIST per sample
            mech_theta: [B, max_dim] parameters; only the first ``dim``
                columns of each sample's mechanism are read
            mech_theta_mask: accepted for interface compatibility; not used

        Returns:
            dict with 'logits' [B, n_mechanisms] and 'nll' [B]
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        logits = self.classifier(context)

        # Sized from context (not x) so summary mode works with x=None.
        nll = torch.zeros(context.shape[0], device=context.device)
        for m_idx, mech in enumerate(TPD_MECHANISM_LIST):
            sel = mechanism_ids == m_idx
            if not sel.any():
                continue
            theta_dim = TPD_MECHANISM_PARAMS[mech]['dim']
            nll[sel] = -self._safe_log_prob(
                mech, mech_theta[sel, :theta_dim], context[sel]
            )

        return {'logits': logits, 'nll': nll}

    def forward_with_calibration(self, x, mechanism_ids, mech_theta,
                                 mech_theta_mask=None, scan_mask=None,
                                 sigmas=None, flux_scales=None,
                                 cal_n_samples=64, cal_levels=(0.5, 0.9),
                                 cal_beta=20.0, summary=None):
        """
        Forward pass with an additional soft-coverage calibration loss.

        For each mechanism group with at least 4 items, ``cal_n_samples``
        samples are drawn via the flow head's ``sample_with_grad``. For every
        level in ``cal_levels``, the loss penalises the squared gap between a
        sigmoid-smoothed "true theta lies inside the central ``level``
        interval" indicator and ``level`` itself.

        Returns:
            dict with 'logits' [B, n_mechanisms], 'nll' [B] and scalar
            'cal_loss' (0.0 when no group was large enough).
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        logits = self.classifier(context)

        nll = torch.zeros(context.shape[0], device=context.device)
        cal_losses = []
        for m_idx, mech in enumerate(TPD_MECHANISM_LIST):
            sel = mechanism_ids == m_idx
            if not sel.any():
                continue
            theta_dim = TPD_MECHANISM_PARAMS[mech]['dim']
            ctx_m = context[sel]
            theta_m = mech_theta[sel, :theta_dim]
            nll[sel] = -self._safe_log_prob(mech, theta_m, ctx_m)

            # Groups with fewer than 4 items get no calibration term.
            if ctx_m.shape[0] < 4:
                continue
            samples = self.flow_heads[mech].sample_with_grad(
                ctx_m, n_samples=cal_n_samples,
            )
            cal_losses.extend(
                self._calibration_terms(samples, theta_m, cal_levels, cal_beta)
            )

        if cal_losses:
            cal_loss = torch.stack(cal_losses).mean()
        else:
            cal_loss = torch.tensor(0.0, device=context.device)

        return {'logits': logits, 'nll': nll, 'cal_loss': cal_loss}

    @torch.no_grad()
    def predict(self, x, scan_mask=None, sigmas=None, flux_scales=None,
                n_samples=200, top_k=None, temperature=1.0,
                temperature_map=None, summary=None):
        """
        Full inference: classify mechanism, then sample parameters.

        Args:
            top_k: if given, a mechanism's flow is sampled only when it is in
                the top-k probabilities of at least one batch item (clamped to
                the number of mechanisms)
            temperature: scalar fallback (>1 broadens posteriors)
            temperature_map: dict mapping mechanism name -> list of
                per-parameter temperatures. Overrides scalar temperature.
            summary: [B, 21] hand-crafted summary stats (summary mode only)

        Returns:
            dict with mechanism_probs, mechanism_xdB (10*log10(p/(1-p))),
            mechanism_pred, samples, stats. ``samples[mech]`` and
            ``stats[mech]`` are None for mechanisms skipped by ``top_k``.
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        logits = self.classifier(context)
        probs = F.softmax(logits, dim=-1)
        pred = probs.argmax(dim=-1)

        # Clamp so the log-odds stay finite at p = 0 or 1.
        probs_clamped = probs.clamp(min=1e-7, max=1 - 1e-7)
        xdB = 10.0 * torch.log10(probs_clamped / (1.0 - probs_clamped))

        # Union over the batch of top-k mechanism indices, computed once.
        candidate_ids = None
        if top_k is not None:
            k = min(top_k, self.n_mechanisms)
            candidate_ids = set(probs.topk(k, dim=-1).indices.flatten().tolist())

        samples_dict = {}
        stats_dict = {}
        for m_idx, mech in enumerate(TPD_MECHANISM_LIST):
            if candidate_ids is not None and m_idx not in candidate_ids:
                samples_dict[mech] = None
                stats_dict[mech] = None
                continue

            T = self._resolve_temperature(mech, temperature, temperature_map,
                                          context.device)
            s = self.flow_heads[mech].sample(context, n_samples=n_samples,
                                             temperature=T)
            samples_dict[mech] = s
            stats_dict[mech] = self._summarize_samples(s)

        return {
            'mechanism_probs': probs,
            'mechanism_xdB': xdB,
            'mechanism_pred': pred,
            'samples': samples_dict,
            'stats': stats_dict,
        }

    @torch.no_grad()
    def predict_single_mechanism(self, x, mechanism, scan_mask=None,
                                 sigmas=None, flux_scales=None, n_samples=1000,
                                 temperature=1.0, temperature_map=None,
                                 summary=None):
        """
        Sample parameters assuming a known mechanism.

        Runs without autograd, like ``predict``.

        Returns:
            dict with 'mean', 'std', 'median', 'q05', 'q95' (over samples)
            and the raw 'samples'.
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        T = self._resolve_temperature(mechanism, temperature, temperature_map,
                                      context.device)
        samples = self.flow_heads[mechanism].sample(context, n_samples=n_samples,
                                                    temperature=T)
        return {**self._summarize_samples(samples), 'samples': samples}


def count_parameters(model):
    """Number of trainable parameters in ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == "__main__":
    # Smoke test: one sample per mechanism, random inputs.
    n_mechs = len(TPD_MECHANISM_LIST)
    B, N_beta, T = n_mechs, 3, 500
    x = torch.randn(B, N_beta, 2, T)
    scan_mask = torch.ones(B, N_beta, T, dtype=torch.bool)
    sigmas = torch.randn(B, N_beta)
    flux_scales = torch.randn(B, N_beta)
    mechanism_ids = torch.arange(n_mechs)

    max_dim = max(TPD_MECHANISM_PARAMS[m]['dim'] for m in TPD_MECHANISM_LIST)
    mech_theta = torch.randn(B, max_dim)
    mech_theta_mask = torch.zeros(B, max_dim, dtype=torch.bool)
    for i, mid in enumerate(mechanism_ids.tolist()):
        d = TPD_MECHANISM_PARAMS[TPD_MECHANISM_LIST[mid]]['dim']
        mech_theta_mask[i, :d] = True

    print("=" * 60)
    print("Testing MultiMechanismFlowTPD (multi-heating-rate, Set Transformer)")
    print("=" * 60)

    model = MultiMechanismFlowTPD(
        d_context=128,
        d_model=128,
        n_coupling_layers=8,
        hidden_dim=128,
        coupling_type='affine',
    )
    total_params = count_parameters(model)
    print(f"Total parameters: {total_params:,}")
    print(f"  Encoder: {count_parameters(model.encoder):,}")
    print(f"  Classifier: {count_parameters(model.classifier):,}")
    for mech in TPD_MECHANISM_LIST:
        print(f"  Flow ({mech}, dim={TPD_MECHANISM_PARAMS[mech]['dim']}): "
              f"{count_parameters(model.flow_heads[mech]):,}")

    out = model(x, mechanism_ids, mech_theta, mech_theta_mask,
                scan_mask=scan_mask, sigmas=sigmas, flux_scales=flux_scales)
    print("\nForward pass:")
    print(f"  Logits shape: {out['logits'].shape}")
    print(f"  NLL shape: {out['nll'].shape}")
    print(f"  NLL values: {out['nll']}")

    pred = model.predict(x, scan_mask=scan_mask, sigmas=sigmas,
                         flux_scales=flux_scales, n_samples=100)
    print("\nPrediction:")
    print(f"  Mechanism probs shape: {pred['mechanism_probs'].shape}")
    print(f"  Mechanism xdB shape: {pred['mechanism_xdB'].shape}")
    print(f"  Predicted mechanisms: {pred['mechanism_pred']}")
    for mech in TPD_MECHANISM_LIST:
        if pred['samples'][mech] is not None:
            print(f"  {mech} samples shape: {pred['samples'][mech].shape}")