""" Conditional Normalizing Flow for Electrochemical Parameter Inference. Models p(theta | x) where: theta = simulator parameters (variable dimension per mechanism) x = (E(t), j(t), t) -- observed electrochemical signal [3, T] Supports two coupling layer types: - Affine: simple scale+shift (original RealNVP) - Spline: rational-quadratic spline (Neural Spline Flows, Durkan et al. 2019) Architecture: 1. SignalEncoder: 1D CNN + global pooling -> fixed-size context vector 2. ConditionalFlow: Coupling layers (affine or spline) conditioned on context Training objective: Negative log-likelihood (NLL) L = -E[log q_phi(theta | x)] = -E[log p_z(f^{-1}(theta; x)) + log |det df^{-1}/d_theta|] """ import torch import torch.nn as nn import torch.nn.functional as F import math # ============================================================================= # Signal Encoder: x -> context vector # ============================================================================= class SignalEncoder(nn.Module): """ Encode variable-length electrochemical waveform into a fixed-size context vector. 
Architecture: 1D CNN -> Global Average Pooling -> MLP Input: [B, 3, T] (potential, flux, time) Output: [B, d_context] """ def __init__(self, in_channels=3, d_model=128, d_context=128): super().__init__() self.conv = nn.Sequential( nn.Conv1d(in_channels, d_model // 2, kernel_size=7, padding=3), nn.GELU(), nn.Conv1d(d_model // 2, d_model, kernel_size=5, padding=2), nn.GELU(), nn.Conv1d(d_model, d_model, kernel_size=3, padding=1), nn.GELU(), ) self.pool_proj = nn.Sequential( nn.Linear(d_model, d_context), nn.GELU(), nn.Linear(d_context, d_context), ) def forward(self, x, mask=None): h = self.conv(x) if mask is not None: mask_expanded = mask.unsqueeze(1).float() h = (h * mask_expanded).sum(dim=-1) / mask_expanded.sum(dim=-1).clamp(min=1) else: h = h.mean(dim=-1) context = self.pool_proj(h) return context # ============================================================================= # Normalizing Flow Components # ============================================================================= class ActNorm(nn.Module): """Activation normalization (from Glow) with data-dependent init.""" def __init__(self, dim): super().__init__() self.log_scale = nn.Parameter(torch.zeros(dim)) self.bias = nn.Parameter(torch.zeros(dim)) self.register_buffer('_initialized', torch.tensor(False)) @property def initialized(self): return bool(self._initialized.item()) @initialized.setter def initialized(self, value): self._initialized.fill_(value) def initialize(self, x): with torch.no_grad(): self.bias.data = -x.mean(dim=0) if x.shape[0] > 1: std = x.std(dim=0).clamp(min=0.1) self.log_scale.data = -torch.log(std) else: self.log_scale.data.zero_() self.initialized = True def forward(self, x): if not self.initialized: self.initialize(x) y = (x + self.bias) * torch.exp(self.log_scale) log_det = self.log_scale.sum() return y, log_det def inverse(self, y): x = y * torch.exp(-self.log_scale) - self.bias return x class ConditionalAffineCoupling(nn.Module): """ Conditional affine coupling layer. 
Forward (z -> theta): theta_b = z_b * exp(s) + t Inverse (theta -> z): z_b = (theta_b - t) * exp(-s) The log-scale s is soft-clamped to [-s_clamp, s_clamp]. With s_clamp=2.0 each layer can scale by up to exp(2)≈7.4x per dimension, giving the flow enough dynamic range to produce both narrow (identifiable) and wide (non-identifiable) posteriors. """ def __init__(self, dim, d_context, hidden_dim=128, mask_type='even', s_clamp=2.0): super().__init__() self.dim = dim self.s_clamp = s_clamp if mask_type == 'even': self.register_buffer('mask', torch.arange(dim) % 2 == 0) else: self.register_buffer('mask', torch.arange(dim) % 2 == 1) n_a = self.mask.sum().item() n_b = dim - n_a self.net = nn.Sequential( nn.Linear(n_a + d_context, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 2 * n_b), ) nn.init.zeros_(self.net[-1].weight) nn.init.zeros_(self.net[-1].bias) def _clamp_s(self, s_raw): """Soft-clamp log-scale using 2*tanh(s/2) for smooth gradients everywhere.""" return self.s_clamp * torch.tanh(s_raw / self.s_clamp) def forward(self, z, context): z_a = z[:, self.mask] z_b = z[:, ~self.mask] st = self.net(torch.cat([z_a, context], dim=-1)) s, t = st.chunk(2, dim=-1) s = self._clamp_s(s) theta_b = z_b * torch.exp(s) + t log_det = s.sum(dim=-1) theta = torch.empty_like(z) theta[:, self.mask] = z_a theta[:, ~self.mask] = theta_b return theta, log_det def inverse(self, theta, context): theta_a = theta[:, self.mask] theta_b = theta[:, ~self.mask] st = self.net(torch.cat([theta_a, context], dim=-1)) s, t = st.chunk(2, dim=-1) s = self._clamp_s(s) z_b = (theta_b - t) * torch.exp(-s) log_det = -s.sum(dim=-1) z = torch.empty_like(theta) z[:, self.mask] = theta_a z[:, ~self.mask] = z_b return z, log_det # ============================================================================= # Rational-Quadratic Spline Transform (Durkan et al. 
# =============================================================================
# Rational-Quadratic Spline Transform (Durkan et al. 2019)
# =============================================================================

MIN_BIN_FRACTION = 1e-2  # each bin gets at least 1% of the total width/height
MIN_DERIVATIVE = 1e-2


def _prepare_spline_params(widths, heights, derivatives, tail_bound):
    """Shared preprocessing for forward and inverse spline transforms.

    Maps unconstrained network outputs to valid spline parameters on
    [-tail_bound, tail_bound]:

    * widths/heights: softmax, then each bin is floored at MIN_BIN_FRACTION of
      the interval (nflows convention). This prevents degenerate
      near-step-function splines that break invertibility.
    * derivatives: softplus + MIN_DERIVATIVE, so every knot slope is positive.

    Returns:
        widths, heights: [..., K] bin sizes, each summing to 2 * tail_bound
        derivatives: [..., K+1] positive knot slopes
        cumwidths, cumheights: [..., K+1] knot positions starting at -tail_bound
    """
    K = widths.shape[-1]
    total = 2 * tail_bound

    widths = F.softmax(widths, dim=-1)
    widths = MIN_BIN_FRACTION + (1 - K * MIN_BIN_FRACTION) * widths
    widths = widths * total

    heights = F.softmax(heights, dim=-1)
    heights = MIN_BIN_FRACTION + (1 - K * MIN_BIN_FRACTION) * heights
    heights = heights * total

    derivatives = F.softplus(derivatives) + MIN_DERIVATIVE

    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, (1, 0), value=0.0) - tail_bound
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, (1, 0), value=0.0) - tail_bound

    return widths, heights, derivatives, cumwidths, cumheights


def _rq_log_abs_det(xi, s_k, d_k, d_k1):
    """log |dy/dx| of a rational-quadratic bin at relative position ``xi``.

    dy/dx = s^2 (d_{k+1} xi^2 + 2 s xi (1-xi) + d_k (1-xi)^2) / denom^2
    with denom = s + (d_k + d_{k+1} - 2 s) xi (1-xi).
    Both factors are clamped at 1e-8 before the log for numerical safety.
    """
    one_minus_xi = 1 - xi
    deriv_numer = s_k.pow(2) * (d_k1 * xi.pow(2)
                                + 2 * s_k * xi * one_minus_xi
                                + d_k * one_minus_xi.pow(2))
    denom = s_k + (d_k + d_k1 - 2 * s_k) * xi * one_minus_xi
    return torch.log(deriv_numer.clamp(min=1e-8)) - 2 * torch.log(denom.abs().clamp(min=1e-8))


def rational_quadratic_spline_forward(x, widths, heights, derivatives, tail_bound=5.0):
    """
    Apply monotonic rational-quadratic spline transform (forward direction).

    Uses identity transform outside [-tail_bound, tail_bound].

    Args:
        x: [B, D] input values
        widths: [B, D, K] bin widths (pre-softmax)
        heights: [B, D, K] bin heights (pre-softmax)
        derivatives: [B, D, K+1] knot derivatives (pre-softplus)

    Returns:
        y: [B, D] transformed values
        log_det: [B, D] log absolute derivative per dimension, clamped to
            [-20, 20] (0 in the identity tails)
    """
    K = widths.shape[-1]
    widths, heights, derivatives, cumwidths, cumheights = \
        _prepare_spline_params(widths, heights, derivatives, tail_bound)

    inside = (x >= -tail_bound) & (x <= tail_bound)

    # Default: identity (for tails)
    y = x.clone()
    log_det = torch.zeros_like(x)

    if not inside.any():
        return y, log_det

    x_in = x[inside]
    cw_in = cumwidths[inside]  # [N_inside, K+1]

    # Searching only the interior/right knots yields the bin index directly.
    bin_idx = torch.searchsorted(cw_in[:, 1:].contiguous(), x_in.unsqueeze(-1)).squeeze(-1)
    idx = bin_idx.clamp(0, K - 1).unsqueeze(-1)

    w_k = widths[inside].gather(-1, idx).squeeze(-1)
    h_k = heights[inside].gather(-1, idx).squeeze(-1)
    d_k = derivatives[inside].gather(-1, idx).squeeze(-1)
    d_k1 = derivatives[inside].gather(-1, idx + 1).squeeze(-1)
    cw_k = cw_in.gather(-1, idx).squeeze(-1)
    ch_k = cumheights[inside].gather(-1, idx).squeeze(-1)

    xi = ((x_in - cw_k) / w_k).clamp(1e-6, 1.0 - 1e-6)
    s_k = h_k / w_k

    numer = h_k * (s_k * xi.pow(2) + d_k * xi * (1 - xi))
    denom = s_k + (d_k + d_k1 - 2 * s_k) * xi * (1 - xi)
    y[inside] = ch_k + numer / denom

    log_det[inside] = _rq_log_abs_det(xi, s_k, d_k, d_k1)
    log_det = log_det.clamp(-20.0, 20.0)

    return y, log_det


def rational_quadratic_spline_inverse(y, widths, heights, derivatives, tail_bound=5.0):
    """
    Apply inverse rational-quadratic spline transform.

    Uses identity transform outside [-tail_bound, tail_bound]. Same argument
    shapes as ``rational_quadratic_spline_forward``. Returns ``(x, log_det)``
    where ``log_det`` is the negated (clamped) forward log-derivative at ``x``.
    """
    K = widths.shape[-1]
    widths, heights, derivatives, cumwidths, cumheights = \
        _prepare_spline_params(widths, heights, derivatives, tail_bound)

    inside = (y >= -tail_bound) & (y <= tail_bound)

    x = y.clone()
    log_det = torch.zeros_like(y)

    if not inside.any():
        return x, log_det

    y_in = y[inside]
    ch_in = cumheights[inside]

    bin_idx = torch.searchsorted(ch_in[:, 1:].contiguous(), y_in.unsqueeze(-1)).squeeze(-1)
    idx = bin_idx.clamp(0, K - 1).unsqueeze(-1)

    w_k = widths[inside].gather(-1, idx).squeeze(-1)
    h_k = heights[inside].gather(-1, idx).squeeze(-1)
    d_k = derivatives[inside].gather(-1, idx).squeeze(-1)
    d_k1 = derivatives[inside].gather(-1, idx + 1).squeeze(-1)
    cw_k = cumwidths[inside].gather(-1, idx).squeeze(-1)
    ch_k = ch_in.gather(-1, idx).squeeze(-1)

    s_k = h_k / w_k
    y_rel = y_in - ch_k

    # Solve a*xi^2 + b*xi + c = 0 for xi (Durkan et al. 2019), using the
    # numerically stable root 2c / (-b - sqrt(b^2 - 4ac)).
    a = h_k * (s_k - d_k) + y_rel * (d_k + d_k1 - 2 * s_k)
    b = h_k * d_k - y_rel * (d_k + d_k1 - 2 * s_k)
    c = -s_k * y_rel

    discriminant = b.pow(2) - 4 * a * c
    sqrt_disc = torch.sqrt(discriminant.clamp(min=1e-8))
    xi = (2 * c) / (-b - sqrt_disc).clamp(max=-1e-8)
    xi = xi.clamp(1e-6, 1.0 - 1e-6)

    x[inside] = cw_k + w_k * xi

    log_det[inside] = _rq_log_abs_det(xi, s_k, d_k, d_k1)
    log_det = log_det.clamp(-20.0, 20.0)

    return x, -log_det


class ConditionalSplineCoupling(nn.Module):
    """
    Conditional Neural Spline coupling layer (Durkan et al. 2019).

    Same interface as ConditionalAffineCoupling but uses rational-quadratic
    splines instead of affine transforms. Much more expressive per layer.

    At construction the layer is exactly the identity inside the spline
    domain. The conditioner's last layer has zero weights, which gives uniform
    bins (slope 1). Its derivative biases are set so that
    softplus(bias) + MIN_DERIVATIVE == 1.

    Args:
        dim: parameter dimension
        d_context: context vector dimension
        hidden_dim: hidden layer size in conditioner network
        mask_type: 'even' or 'odd' (which indices pass through unchanged)
        n_bins: number of spline bins (K); requires n_bins * MIN_BIN_FRACTION < 1
        tail_bound: spline domain [-B, B]

    Raises:
        ValueError: if n_bins is < 1 or too large for the MIN_BIN_FRACTION floor.
    """

    def __init__(self, dim, d_context, hidden_dim=128, mask_type='even', n_bins=8, tail_bound=5.0):
        super().__init__()
        # With K bins each floored at MIN_BIN_FRACTION, K * MIN_BIN_FRACTION >= 1
        # would freeze the bins (== 1) or allow negative widths (> 1).
        if n_bins < 1 or n_bins * MIN_BIN_FRACTION >= 1:
            raise ValueError(
                f"n_bins must be >= 1 and n_bins * MIN_BIN_FRACTION < 1 (got n_bins={n_bins})"
            )
        self.dim = dim
        self.n_bins = n_bins
        self.tail_bound = tail_bound

        if mask_type == 'even':
            self.register_buffer('mask', torch.arange(dim) % 2 == 0)
        else:
            self.register_buffer('mask', torch.arange(dim) % 2 == 1)

        n_a = self.mask.sum().item()
        n_b = dim - n_a
        self.n_b = n_b

        # Output: K widths + K heights + (K+1) derivatives per transformed dim
        n_params = 3 * n_bins + 1
        n_out = n_b * n_params

        self.net = nn.Sequential(
            nn.Linear(n_a + d_context, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, n_out),
        )
        final = self.net[-1]
        nn.init.zeros_(final.weight)
        nn.init.zeros_(final.bias)
        # Zero raw derivatives would give softplus(0) + MIN_DERIVATIVE ~= 0.70,
        # which makes the initial spline a non-identity map. Choose c with
        # softplus(c) + MIN_DERIVATIVE == 1. This matches the unit bin slopes,
        # so the spline starts at the identity.
        identity_deriv = math.log(math.expm1(1.0 - MIN_DERIVATIVE))
        with torch.no_grad():
            final.bias.view(n_b, n_params)[:, 2 * n_bins:] = identity_deriv

    def _get_spline_params(self, z_a, context):
        """Compute raw (pre-softmax / pre-softplus) spline parameters from the conditioner."""
        raw = self.net(torch.cat([z_a, context], dim=-1))  # [B, n_b*(3K+1)]
        raw = raw.view(-1, self.n_b, 3 * self.n_bins + 1)

        K = self.n_bins
        widths = raw[..., :K]
        heights = raw[..., K:2 * K]
        derivatives = raw[..., 2 * K:]

        # Bound raw outputs to prevent extreme spline configurations.
        # Clamping width/height logits to [-5, 5] caps the ratio between any
        # two bins' softmax weights at e^10, before the MIN_BIN_FRACTION floor.
        # This prevents near-one-hot bin allocations. Clamping the derivative
        # logits bounds knot slopes to roughly [0.017, 5.02].
        widths = widths.clamp(-5.0, 5.0)
        heights = heights.clamp(-5.0, 5.0)
        derivatives = derivatives.clamp(-5.0, 5.0)

        return widths, heights, derivatives

    def forward(self, z, context):
        """Map z -> theta. Returns ``(theta, log_det)`` with ``log_det`` of shape [B]."""
        z_a = z[:, self.mask]
        z_b = z[:, ~self.mask]

        widths, heights, derivatives = self._get_spline_params(z_a, context)
        theta_b, log_det_per_dim = rational_quadratic_spline_forward(
            z_b, widths, heights, derivatives, self.tail_bound)
        log_det = log_det_per_dim.sum(dim=-1)

        theta = torch.empty_like(z)
        theta[:, self.mask] = z_a
        theta[:, ~self.mask] = theta_b
        return theta, log_det

    def inverse(self, theta, context):
        """Map theta -> z. Returns ``(z, log_det)`` with ``log_det`` of shape [B]."""
        theta_a = theta[:, self.mask]
        theta_b = theta[:, ~self.mask]

        widths, heights, derivatives = self._get_spline_params(theta_a, context)
        z_b, log_det_per_dim = rational_quadratic_spline_inverse(
            theta_b, widths, heights, derivatives, self.tail_bound)
        log_det = log_det_per_dim.sum(dim=-1)

        z = torch.empty_like(theta)
        z[:, self.mask] = theta_a
        z[:, ~self.mask] = z_b
        return z, log_det


# =============================================================================
# Full Conditional Normalizing Flow
# =============================================================================

class ConditionalFlow(nn.Module):
    """
    Conditional normalizing flow for p(theta | x).

    Maps between base distribution z ~ N(0, I) and parameter space theta,
    conditioned on the observed signal x. Supports both affine and spline
    coupling layers.

    In the z -> theta direction the layer stack is ``n_coupling_layers``
    repetitions of [ActNorm, coupling], with coupling masks alternating
    'even'/'odd'. ``coupling_type='spline'`` selects spline couplings; any
    other value selects affine couplings. The flow operates on theta
    normalized with ``theta_mean`` / ``theta_std`` (see ``set_theta_stats``).
    """

    def __init__(
        self,
        theta_dim=3,
        d_context=128,
        n_coupling_layers=8,
        hidden_dim=128,
        d_model=128,
        coupling_type='spline',
        n_bins=8,
        tail_bound=5.0,
    ):
        super().__init__()
        self.theta_dim = theta_dim
        self.d_context = d_context
        self.coupling_type = coupling_type
        self.tail_bound = tail_bound

        self.encoder = SignalEncoder(
            in_channels=3, d_model=d_model, d_context=d_context
        )

        self.flows = nn.ModuleList()
        for i in range(n_coupling_layers):
            mask_type = 'even' if i % 2 == 0 else 'odd'
            self.flows.append(ActNorm(theta_dim))
            if coupling_type == 'spline':
                self.flows.append(
                    ConditionalSplineCoupling(
                        dim=theta_dim,
                        d_context=d_context,
                        hidden_dim=hidden_dim,
                        mask_type=mask_type,
                        n_bins=n_bins,
                        tail_bound=tail_bound,
                    )
                )
            else:
                self.flows.append(
                    ConditionalAffineCoupling(
                        dim=theta_dim,
                        d_context=d_context,
                        hidden_dim=hidden_dim,
                        mask_type=mask_type,
                    )
                )

        self.register_buffer('theta_mean', torch.zeros(theta_dim))
        self.register_buffer('theta_std', torch.ones(theta_dim))

    def set_theta_stats(self, mean, std):
        """Store per-dimension normalization statistics for theta."""
        self.theta_mean.copy_(torch.as_tensor(mean, dtype=torch.float32))
        self.theta_std.copy_(torch.as_tensor(std, dtype=torch.float32))

    def normalize_theta(self, theta):
        return (theta - self.theta_mean) / self.theta_std

    def denormalize_theta(self, theta_norm):
        return theta_norm * self.theta_std + self.theta_mean

    def encode_signal(self, x, mask=None):
        return self.encoder(x, mask=mask)

    def forward_flow(self, z, context):
        """z -> normalized theta. Returns ``(theta_norm, log_det)`` with log_det of shape [B]."""
        # new_zeros keeps the accumulator on z's device and dtype.
        total_log_det = z.new_zeros(z.shape[0])
        h = z
        for layer in self.flows:
            if isinstance(layer, ActNorm):
                h, ld = layer(h)
            else:
                h, ld = layer(h, context)
            total_log_det = total_log_det + ld
        return h, total_log_det

    def inverse_flow(self, theta_norm, context):
        """Normalized theta -> z. Returns ``(z, log_det)`` with log_det of shape [B]."""
        total_log_det = theta_norm.new_zeros(theta_norm.shape[0])
        h = theta_norm
        for layer in reversed(self.flows):
            if isinstance(layer, ActNorm):
                # ActNorm.inverse returns only x; its log-det is -sum(log_scale),
                # read after the call.
                h = layer.inverse(h)
                total_log_det = total_log_det - layer.log_scale.sum()
            else:
                h, ld = layer.inverse(h, context)
                total_log_det = total_log_det + ld
        return h, total_log_det

    def log_prob(self, theta, x, mask=None):
        """
        Return log q(theta | x) in the original (un-normalized) theta units, shape [B].

        Notes:
            * Normalized theta is clamped to [-tail_bound, tail_bound] before
              evaluation. Values further out are scored as if on the boundary.
            * The result is clamped to [-50, 50]. Clamped entries receive
              zero gradient.
        """
        context = self.encode_signal(x, mask=mask)
        theta_norm = self.normalize_theta(theta)
        theta_norm = theta_norm.clamp(-self.tail_bound, self.tail_bound)

        z, log_det = self.inverse_flow(theta_norm, context)
        log_pz = -0.5 * (z ** 2 + math.log(2 * math.pi)).sum(dim=-1)
        # Jacobian of theta -> theta_norm = (theta - mean) / std.
        log_det_norm = -torch.log(self.theta_std).sum()

        log_p = log_pz + log_det + log_det_norm
        return log_p.clamp(min=-50.0, max=50.0)

    @torch.no_grad()
    def sample(self, x, mask=None, n_samples=100):
        """Draw ``n_samples`` posterior samples per signal; returns [B, n_samples, theta_dim]."""
        B = x.shape[0]
        context = self.encode_signal(x, mask=mask)
        context_rep = context.unsqueeze(1).expand(-1, n_samples, -1)
        context_rep = context_rep.reshape(B * n_samples, -1)

        z = torch.randn(B * n_samples, self.theta_dim, device=x.device)
        theta_norm, _ = self.forward_flow(z, context_rep)
        theta = self.denormalize_theta(theta_norm)
        return theta.reshape(B, n_samples, self.theta_dim)

    @torch.no_grad()
    def posterior_stats(self, x, mask=None, n_samples=1000):
        """Summary statistics over ``n_samples`` posterior samples (per signal)."""
        samples = self.sample(x, mask=mask, n_samples=n_samples)
        return {
            'mean': samples.mean(dim=1),
            'std': samples.std(dim=1),
            'median': samples.median(dim=1).values,
            'q05': samples.quantile(0.05, dim=1),
            'q95': samples.quantile(0.95, dim=1),
            'samples': samples,
        }


# Backward-compatible alias
ConditionalRealNVP = ConditionalFlow


# =============================================================================
# Helper
# =============================================================================

THETA_NAMES = ['log10(K0)', 'alpha', 'log10(dB)']

# Per-mechanism parameter definitions (variable dim per mechanism)
MECHANISM_PARAMS = {
    'Nernst': {
        'names': ['E0_offset', 'log10(dA)', 'log10(dB)'],
        'dim': 3,
    },
    'BV': {
        'names': ['log10(K0)', 'alpha', 'log10(dB)'],
        'dim': 3,
    },
    'MHC': {
        'names': ['log10(K0)', 'log10(reorg_e)', 'log10(dB)'],
        'dim': 3,
    },
    'Ads': {
        'names': ['log10(K0)', 'alpha', 'log10(Gamma_sat)'],
        'dim': 3,
    },
    'EC': {
        'names': ['log10(K0)', 'alpha', 'log10(kc)', 'log10(dB)'],
        'dim': 4,
    },
    'LH': {
        'names': ['log10(K0)', 'alpha', 'log10(KA_eq)', 'log10(KB_eq)', 'log10(dB)'],
        'dim': 5,
    },
}

MECHANISM_LIST = ['Nernst', 'BV', 'MHC', 'Ads', 'EC', 'LH']


def count_parameters(model):
    """Number of trainable (requires_grad) scalar parameters in ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == "__main__":
    B, T = 4, 800
    theta_dim = 3
    x = torch.randn(B, 3, T)
    theta = torch.randn(B, theta_dim)
    mask = torch.ones(B, T, dtype=torch.bool)

    for coupling_type in ['affine', 'spline']:
        print(f"\n{'=' * 50}")
        print(f"Testing ConditionalFlow (coupling={coupling_type})")
        print(f"{'=' * 50}")

        model = ConditionalFlow(
            theta_dim=theta_dim,
            d_context=128,
            n_coupling_layers=8,
            hidden_dim=128,
            d_model=128,
            coupling_type=coupling_type,
        )
        print(f"Parameters: {count_parameters(model):,}")

        log_q = model.log_prob(theta, x, mask=mask)
        print(f"log_prob shape: {log_q.shape}, values: {log_q}")

        samples = model.sample(x, mask=mask, n_samples=100)
        print(f"Samples shape: {samples.shape}")
        print(f"Sample mean: {samples.mean(dim=1)}")
        print(f"Sample std: {samples.std(dim=1)}")