# ecflow/tpd_model.py
# Author: Bing Yan
# Initial ECFlow deployment (commit d6b782e)
"""
Multi-Mechanism Normalizing Flow for Joint Mechanism Identification
and Bayesian Parameter Inference from Multi-Heating-Rate TPD Signals.
Architecture mirrors the electrochemistry model (multi_mechanism_model.py)
but configured for TPD with 2 input channels (temperature, rate) and
6 catalysis mechanisms.
Reuses domain-agnostic components from flow_model.py and
multi_mechanism_model.py (SAB, PMA, MechanismClassifier, MechanismFlow).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from flow_model import (
SignalEncoder,
ActNorm,
ConditionalSplineCoupling,
ConditionalAffineCoupling,
)
from multi_mechanism_model import MechanismClassifier, MechanismFlow, SAB, PMA, SummaryProjection, SUMMARY_DIM
from generate_tpd_data import TPD_MECHANISM_LIST, TPD_MECHANISM_PARAMS
class MultiScanEncoderTPD(nn.Module):
    """
    Encode a set of multi-heating-rate TPD curves into one context vector.

    Set-Transformer pipeline:
        1. Shared per-curve CNN encoder -> per-curve embedding
        2. Concatenate [log10(heating_rate), log10(peak_rate)] and project
        3. SAB: self-attention across the heating-rate set
        4. PMA: attention-based pooling down to a single seed vector
        5. rho MLP: final projection to context space

    Input: x [B, N_beta, 2, T], scan_mask [B, N_beta, T],
    sigmas [B, N_beta], flux_scales [B, N_beta]
    Output: context [B, d_context]
    """
    def __init__(self, in_channels=2, d_model=128, d_context=128, n_heads=4):
        super().__init__()
        # Shared CNN applied independently to every heating-rate curve.
        self.per_cv_encoder = SignalEncoder(
            in_channels=in_channels, d_model=d_model, d_context=d_context,
        )
        # Mixes the two scalar per-curve metadata values into the embedding.
        self.cv_augment = nn.Sequential(
            nn.Linear(d_context + 2, d_context),
            nn.GELU(),
        )
        self.sab = SAB(d_context, n_heads=n_heads)
        self.pma = PMA(d_context, n_heads=n_heads, n_seeds=1)
        self.rho = nn.Sequential(
            nn.Linear(d_context, d_context),
            nn.GELU(),
            nn.Linear(d_context, d_context),
        )

    def forward(self, x, scan_mask=None, sigmas=None, flux_scales=None):
        """
        Args:
            x: [B, N_beta, 2, T] multi-heating-rate TPD curves
            scan_mask: [B, N_beta, T] valid-timestep mask (True = valid)
            sigmas: [B, N_beta] log10 heating rates (zeros when omitted)
            flux_scales: [B, N_beta] log10(peak_rate) per curve (zeros when omitted)

        Returns:
            context: [B, d_context]
        """
        batch, n_curves, n_chan, n_t = x.shape

        # Fold the curve axis into the batch so the CNN sees one curve at a time.
        flat_curves = x.reshape(batch * n_curves, n_chan, n_t)
        flat_mask = (
            None if scan_mask is None
            else scan_mask.reshape(batch * n_curves, n_t)
        )
        emb = self.per_cv_encoder(flat_curves, mask=flat_mask)
        emb = emb.reshape(batch, n_curves, -1)

        # Missing metadata defaults to zeros on the input's device.
        if sigmas is None:
            sigmas = torch.zeros(batch, n_curves, device=x.device)
        if flux_scales is None:
            flux_scales = torch.zeros(batch, n_curves, device=x.device)
        meta = torch.stack([sigmas, flux_scales], dim=-1)
        emb = self.cv_augment(torch.cat([emb, meta], dim=-1))

        # A curve whose mask is all-False is padding; hide it from attention.
        pad = None if scan_mask is None else ~scan_mask.any(dim=-1)  # [B, N]

        emb = self.sab(emb, key_padding_mask=pad)
        pooled = self.pma(emb, key_padding_mask=pad).squeeze(1)  # [B, d_context]
        return self.rho(pooled)
class MultiMechanismFlowTPD(nn.Module):
    """
    Joint mechanism identification and parameter inference model for TPD.

    Combines:
      - Multi-heating-rate signal encoder (Set Transformer over per-curve
        embeddings)
      - Mechanism classifier over the TPD mechanisms
      - Per-mechanism conditional normalizing flow heads

    If use_summary_features=True, replaces the signal encoder with a simple
    MLP projection from hand-crafted summary statistics (SUMMARY_DIM) to
    context space, keeping all other components identical.
    """
    def __init__(
        self,
        d_context=128,
        d_model=128,
        n_coupling_layers=6,
        hidden_dim=96,
        coupling_type='spline',
        n_bins=8,
        tail_bound=5.0,
        use_summary_features=False,
    ):
        super().__init__()
        self.n_mechanisms = len(TPD_MECHANISM_LIST)
        self.mechanism_list = TPD_MECHANISM_LIST
        self.d_context = d_context
        self.use_summary_features = use_summary_features
        if use_summary_features:
            # Summary mode: hand-crafted statistics -> context via a small MLP.
            self.summary_proj = SummaryProjection(
                summary_dim=SUMMARY_DIM, d_context=d_context,
            )
            self.encoder = None
        else:
            self.encoder = MultiScanEncoderTPD(
                in_channels=2, d_model=d_model, d_context=d_context,
            )
            self.summary_proj = None
        self.classifier = MechanismClassifier(
            d_context=d_context,
            n_mechanisms=self.n_mechanisms,
            hidden_dim=hidden_dim,
        )
        # One flow head per mechanism; theta dimensionality varies by mechanism.
        self.flow_heads = nn.ModuleDict()
        for mech in TPD_MECHANISM_LIST:
            theta_dim = TPD_MECHANISM_PARAMS[mech]['dim']
            self.flow_heads[mech] = MechanismFlow(
                theta_dim=theta_dim,
                d_context=d_context,
                n_coupling_layers=n_coupling_layers,
                hidden_dim=hidden_dim,
                coupling_type=coupling_type,
                n_bins=n_bins,
                tail_bound=tail_bound,
            )

    def set_theta_stats(self, mechanism, mean, std):
        """Set normalization stats for a specific mechanism's flow head."""
        self.flow_heads[mechanism].set_theta_stats(mean, std)

    def encode_signal(self, x, scan_mask=None, sigmas=None, flux_scales=None,
                      summary=None):
        """Map raw curves (or summary statistics, in summary mode) to a
        context vector [B, d_context]."""
        if self.use_summary_features:
            assert summary is not None, "summary features required in summary mode"
            return self.summary_proj(summary)
        return self.encoder(x, scan_mask=scan_mask, sigmas=sigmas,
                            flux_scales=flux_scales)

    def _safe_log_prob(self, mech, theta_m, ctx_m):
        """log_prob under one flow head, with non-finite entries replaced by
        a detached -10.0 so isolated numerical failures cannot poison the
        gradient of the batch."""
        log_p = self.flow_heads[mech].log_prob(theta_m, ctx_m)
        bad = ~torch.isfinite(log_p)
        if bad.any():
            log_p = torch.where(bad, torch.full_like(log_p, -10.0).detach(), log_p)
        return log_p

    def _resolve_temperature(self, mechanism, temperature, temperature_map,
                             device):
        """Return the per-parameter temperature vector for `mechanism` from
        temperature_map when present, else the scalar fallback.

        The vector is created on `device` so GPU inference does not mix
        CPU and CUDA tensors (the original built it on CPU unconditionally).
        """
        if temperature_map is not None and mechanism in temperature_map:
            return torch.tensor(temperature_map[mechanism],
                                dtype=torch.float32, device=device)
        return temperature

    @staticmethod
    def _posterior_stats(samples):
        """Per-parameter posterior summaries from samples [B, n_samples, D]."""
        return {
            'mean': samples.mean(dim=1),
            'std': samples.std(dim=1),
            'median': samples.median(dim=1).values,
            'q05': samples.quantile(0.05, dim=1),
            'q95': samples.quantile(0.95, dim=1),
        }

    def forward(self, x, mechanism_ids, mech_theta, mech_theta_mask=None,
                scan_mask=None, sigmas=None, flux_scales=None, summary=None):
        """
        Compute classification logits and per-sample NLL for the true mechanism.

        Args:
            x: [B, N_beta, 2, T] multi-heating-rate curves
            mechanism_ids: [B] integer index into TPD_MECHANISM_LIST
            mech_theta: [B, max_dim] true parameters, padded past each
                mechanism's own dimensionality
            mech_theta_mask: accepted for API compatibility; unused here —
                slicing mech_theta to each mechanism's theta_dim already
                selects the valid columns
            summary: [B, SUMMARY_DIM] statistics (summary mode only)

        Returns:
            dict with 'logits' [B, n_mechanisms] and 'nll' [B]
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        logits = self.classifier(context)
        nll = torch.zeros(x.shape[0], device=x.device)
        for m_idx, mech in enumerate(TPD_MECHANISM_LIST):
            sel = (mechanism_ids == m_idx)
            if not sel.any():
                continue
            theta_dim = TPD_MECHANISM_PARAMS[mech]['dim']
            # Only the first theta_dim columns are meaningful for this mechanism.
            log_p = self._safe_log_prob(mech, mech_theta[sel, :theta_dim],
                                        context[sel])
            nll[sel] = -log_p
        return {'logits': logits, 'nll': nll}

    def forward_with_calibration(self, x, mechanism_ids, mech_theta,
                                 mech_theta_mask=None, scan_mask=None,
                                 sigmas=None, flux_scales=None,
                                 cal_n_samples=64, cal_levels=(0.5, 0.9),
                                 cal_beta=20.0, summary=None):
        """Forward pass with additional calibration loss (see MultiMechanismFlow).

        For each mechanism with at least 4 selected samples, draws
        cal_n_samples posterior samples with gradient and penalizes the
        squared gap between soft central-interval coverage and each target
        level in cal_levels; cal_beta sets the sharpness of the sigmoid
        inside-interval indicator.

        Returns:
            dict with 'logits' [B, n_mechanisms], 'nll' [B], 'cal_loss' scalar
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        logits = self.classifier(context)
        nll = torch.zeros(x.shape[0], device=x.device)
        cal_losses = []
        for m_idx, mech in enumerate(TPD_MECHANISM_LIST):
            sel = (mechanism_ids == m_idx)
            if not sel.any():
                continue
            theta_dim = TPD_MECHANISM_PARAMS[mech]['dim']
            ctx_m = context[sel]
            theta_m = mech_theta[sel, :theta_dim]
            nll[sel] = -self._safe_log_prob(mech, theta_m, ctx_m)
            # Too few samples for a meaningful coverage estimate.
            if ctx_m.shape[0] < 4:
                continue
            samples = self.flow_heads[mech].sample_with_grad(
                ctx_m, n_samples=cal_n_samples,
            )
            # Inverse-spread weights: collapsed parameters (small posterior std)
            # get more calibration pressure than well-spread parameters.
            with torch.no_grad():
                param_std = samples.std(dim=1).clamp(min=1e-4)  # [B_m, D]
                inv_spread_w = 1.0 / param_std  # [B_m, D]
                inv_spread_w = inv_spread_w / inv_spread_w.mean()
            for level in cal_levels:
                alpha = (1.0 - level) / 2.0
                lower = torch.quantile(samples, alpha, dim=1)  # [B_m, D]
                upper = torch.quantile(samples, 1 - alpha, dim=1)  # [B_m, D]
                # Soft indicator of theta_m lying inside [lower, upper];
                # differentiable w.r.t. the sampled quantiles.
                inside = (
                    torch.sigmoid(cal_beta * (theta_m - lower))
                    * torch.sigmoid(cal_beta * (upper - theta_m))
                )  # [B_m, D]
                per_sample_loss = (inside - level).pow(2)  # [B_m, D]
                cal_losses.append((per_sample_loss * inv_spread_w).mean())
        if cal_losses:
            cal_loss = torch.stack(cal_losses).mean()
        else:
            cal_loss = torch.tensor(0.0, device=x.device)
        return {'logits': logits, 'nll': nll, 'cal_loss': cal_loss}

    @torch.no_grad()
    def predict(self, x, scan_mask=None, sigmas=None, flux_scales=None,
                n_samples=200, top_k=None, temperature=1.0,
                temperature_map=None, summary=None):
        """
        Full inference: classify mechanism, then sample parameters.

        Args:
            n_samples: posterior samples drawn per retained mechanism
            top_k: if set, only sample mechanisms that appear in at least one
                sample's top-k probabilities; skipped mechanisms get None
            temperature: scalar fallback (>1 broadens posteriors)
            temperature_map: dict mapping mechanism name -> list of
                per-parameter temperatures. Overrides scalar temperature.
            summary: [B, SUMMARY_DIM] hand-crafted summary stats (summary mode only)

        Returns:
            dict with mechanism_probs, mechanism_xdB, mechanism_pred, samples, stats
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        logits = self.classifier(context)
        probs = F.softmax(logits, dim=-1)
        pred = probs.argmax(dim=-1)
        # Log-odds in decibels; clamping avoids +/-inf at saturated probs.
        probs_clamped = probs.clamp(min=1e-7, max=1 - 1e-7)
        xdB = 10.0 * torch.log10(probs_clamped / (1.0 - probs_clamped))
        # Hoisted out of the loop: the top-k set is loop-invariant
        # (the original recomputed topk once per mechanism).
        top_k_mechs = (probs.topk(top_k, dim=-1).indices
                       if top_k is not None else None)
        samples_dict = {}
        stats_dict = {}
        for m_idx, mech in enumerate(TPD_MECHANISM_LIST):
            if top_k_mechs is not None and not (top_k_mechs == m_idx).any():
                samples_dict[mech] = None
                stats_dict[mech] = None
                continue
            T = self._resolve_temperature(mech, temperature, temperature_map,
                                          context.device)
            s = self.flow_heads[mech].sample(context, n_samples=n_samples,
                                             temperature=T)
            samples_dict[mech] = s
            stats_dict[mech] = self._posterior_stats(s)
        return {
            'mechanism_probs': probs,
            'mechanism_xdB': xdB,
            'mechanism_pred': pred,
            'samples': samples_dict,
            'stats': stats_dict,
        }

    def predict_single_mechanism(self, x, mechanism, scan_mask=None,
                                 sigmas=None, flux_scales=None, n_samples=1000,
                                 temperature=1.0, temperature_map=None,
                                 summary=None):
        """Sample parameters assuming a known mechanism.

        Returns:
            dict of posterior summaries ('mean', 'std', 'median', 'q05',
            'q95') plus the raw 'samples' [B, n_samples, D].
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        T = self._resolve_temperature(mechanism, temperature, temperature_map,
                                      context.device)
        samples = self.flow_heads[mechanism].sample(context, n_samples=n_samples,
                                                    temperature=T)
        return {**self._posterior_stats(samples), 'samples': samples}
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
if __name__ == "__main__":
n_mechs = len(TPD_MECHANISM_LIST)
B, N_beta, T = n_mechs, 3, 500
x = torch.randn(B, N_beta, 2, T)
scan_mask = torch.ones(B, N_beta, T, dtype=torch.bool)
sigmas = torch.randn(B, N_beta)
flux_scales = torch.randn(B, N_beta)
mechanism_ids = torch.arange(n_mechs)
max_dim = max(TPD_MECHANISM_PARAMS[m]['dim'] for m in TPD_MECHANISM_LIST)
mech_theta = torch.randn(B, max_dim)
mech_theta_mask = torch.zeros(B, max_dim, dtype=torch.bool)
for i, mid in enumerate(mechanism_ids):
d = TPD_MECHANISM_PARAMS[TPD_MECHANISM_LIST[mid]]['dim']
mech_theta_mask[i, :d] = True
print("=" * 60)
print("Testing MultiMechanismFlowTPD (multi-heating-rate, Set Transformer)")
print("=" * 60)
model = MultiMechanismFlowTPD(
d_context=128,
d_model=128,
n_coupling_layers=8,
hidden_dim=128,
coupling_type='affine',
)
total_params = count_parameters(model)
print(f"Total parameters: {total_params:,}")
print(f" Encoder: {count_parameters(model.encoder):,}")
print(f" Classifier: {count_parameters(model.classifier):,}")
for mech in TPD_MECHANISM_LIST:
print(f" Flow ({mech}, dim={TPD_MECHANISM_PARAMS[mech]['dim']}): "
f"{count_parameters(model.flow_heads[mech]):,}")
out = model(x, mechanism_ids, mech_theta, mech_theta_mask,
scan_mask=scan_mask, sigmas=sigmas, flux_scales=flux_scales)
print(f"\nForward pass:")
print(f" Logits shape: {out['logits'].shape}")
print(f" NLL shape: {out['nll'].shape}")
print(f" NLL values: {out['nll']}")
pred = model.predict(x, scan_mask=scan_mask, sigmas=sigmas,
flux_scales=flux_scales, n_samples=100)
print(f"\nPrediction:")
print(f" Mechanism probs shape: {pred['mechanism_probs'].shape}")
print(f" Mechanism xdB shape: {pred['mechanism_xdB'].shape}")
print(f" Predicted mechanisms: {pred['mechanism_pred']}")
for mech in TPD_MECHANISM_LIST:
if pred['samples'][mech] is not None:
print(f" {mech} samples shape: {pred['samples'][mech].shape}")