| """ |
| Conditional Normalizing Flow for Electrochemical Parameter Inference. |
| |
| Models p(theta | x) where: |
| theta = simulator parameters (variable dimension per mechanism) |
| x = (E(t), j(t), t) -- observed electrochemical signal [3, T] |
| |
| Supports two coupling layer types: |
| - Affine: simple scale+shift (original RealNVP) |
| - Spline: rational-quadratic spline (Neural Spline Flows, Durkan et al. 2019) |
| |
| Architecture: |
| 1. SignalEncoder: 1D CNN + global pooling -> fixed-size context vector |
| 2. ConditionalFlow: Coupling layers (affine or spline) conditioned on context |
| |
| Training objective: Negative log-likelihood (NLL) |
| L = -E[log q_phi(theta | x)] |
        = -E[log p_z(f^{-1}(theta; x)) + log |det df^{-1}/dtheta|]
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
|
|
|
|
| |
| |
| |
|
|
class SignalEncoder(nn.Module):
    """
    Turn a variable-length electrochemical waveform into a fixed-size
    context vector.

    Pipeline: three-stage 1D CNN -> (masked) global average pool over time
    -> two-layer MLP projection.

    Input:  [B, 3, T] channels (potential, flux, time)
    Output: [B, d_context]
    """
    def __init__(self, in_channels=3, d_model=128, d_context=128):
        super().__init__()
        # Shrinking kernels (7 -> 5 -> 3) with 'same' padding keep length T.
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels, d_model // 2, kernel_size=7, padding=3),
            nn.GELU(),
            nn.Conv1d(d_model // 2, d_model, kernel_size=5, padding=2),
            nn.GELU(),
            nn.Conv1d(d_model, d_model, kernel_size=3, padding=1),
            nn.GELU(),
        )
        self.pool_proj = nn.Sequential(
            nn.Linear(d_model, d_context),
            nn.GELU(),
            nn.Linear(d_context, d_context),
        )

    def forward(self, x, mask=None):
        """Encode x [B, 3, T] to [B, d_context].

        mask: optional [B, T] bool/0-1 tensor marking valid time steps.
        NOTE(review): the convolutions still see padded positions, so the
        mask only affects pooling — features near the valid/padded boundary
        can leak padding. Confirm this is acceptable for the data pipeline.
        """
        feats = self.conv(x)
        if mask is None:
            pooled = feats.mean(dim=-1)
        else:
            # Masked mean: zero padded steps, then divide by the true
            # length (clamped to avoid division by zero on empty masks).
            valid = mask.unsqueeze(1).float()
            pooled = (feats * valid).sum(dim=-1) / valid.sum(dim=-1).clamp(min=1)
        return self.pool_proj(pooled)
|
|
|
|
| |
| |
| |
|
|
class ActNorm(nn.Module):
    """Activation normalization (from Glow) with data-dependent init.

    Per-dimension affine map  y = (x + bias) * exp(log_scale).

    On the first batch it sees — in either direction — the parameters are
    initialized so that the output of that first call is standardized
    (zero mean, unit std per dimension); afterwards bias/log_scale are
    ordinary learned parameters.

    Fix vs. the original: initialization previously only happened in
    `forward`, but NLL training drives data through `inverse`
    (theta -> z), so the data-dependent init never fired on real data.
    `inverse` now also initializes on first use. Callers that compute the
    inverse log-det as -log_scale.sum() remain correct, since
    d(inverse)/dy = exp(-log_scale) always.
    """
    def __init__(self, dim):
        super().__init__()
        self.log_scale = nn.Parameter(torch.zeros(dim))
        self.bias = nn.Parameter(torch.zeros(dim))
        # Buffer (not a parameter) so the init flag survives checkpointing.
        self.register_buffer('_initialized', torch.tensor(False))

    @property
    def initialized(self):
        return bool(self._initialized.item())

    @initialized.setter
    def initialized(self, value):
        self._initialized.fill_(value)

    def initialize(self, x):
        """Data-dependent init from a forward-direction batch: makes
        forward(x) standardized per dimension."""
        with torch.no_grad():
            self.bias.data = -x.mean(dim=0)
            if x.shape[0] > 1:
                # Clamp avoids exploding scales on near-constant dimensions.
                std = x.std(dim=0).clamp(min=0.1)
                self.log_scale.data = -torch.log(std)
            else:
                # Single sample: std undefined, keep unit scale.
                self.log_scale.data.zero_()
        self.initialized = True

    def initialize_inverse(self, y):
        """Data-dependent init from an inverse-direction batch: makes
        inverse(y) standardized per dimension.

        Solves (y - mean)/std == y*exp(-log_scale) - bias, giving
        log_scale = log(std) and bias = mean/std.
        """
        with torch.no_grad():
            if y.shape[0] > 1:
                std = y.std(dim=0).clamp(min=0.1)
            else:
                std = torch.ones_like(self.log_scale)
            self.log_scale.data = torch.log(std)
            self.bias.data = y.mean(dim=0) / std
        self.initialized = True

    def forward(self, x):
        """Return (y, log_det) with y = (x + bias) * exp(log_scale).

        log_det is a scalar: the per-sample log|det J|, identical for
        every sample in the batch.
        """
        if not self.initialized:
            self.initialize(x)
        y = (x + self.bias) * torch.exp(self.log_scale)
        log_det = self.log_scale.sum()
        return y, log_det

    def inverse(self, y):
        """Return x = y * exp(-log_scale) - bias.

        Returns only x (no log_det) to match existing callers, which
        account for the inverse Jacobian via -log_scale.sum().
        """
        if not self.initialized:
            self.initialize_inverse(y)
        x = y * torch.exp(-self.log_scale) - self.bias
        return x
|
|
|
|
class ConditionalAffineCoupling(nn.Module):
    """
    Conditional affine coupling layer (RealNVP-style).

    Splits the vector into a passive half (selected by `mask`) and an
    active half. A conditioner MLP maps (passive half, context) to a
    log-scale s and shift t applied to the active half:

        forward  (z -> theta):  theta_b = z_b * exp(s) + t
        inverse  (theta -> z):  z_b     = (theta_b - t) * exp(-s)

    The log-scale is soft-clamped to (-s_clamp, s_clamp). With the default
    s_clamp=2.0 each layer can rescale a dimension by up to exp(2)≈7.4x,
    giving the flow enough dynamic range for both narrow (identifiable)
    and wide (non-identifiable) posteriors. The last conditioner layer is
    zero-initialized, so the layer starts as the identity map.
    """
    def __init__(self, dim, d_context, hidden_dim=128, mask_type='even', s_clamp=2.0):
        super().__init__()
        self.dim = dim
        self.s_clamp = s_clamp

        # Passive dimensions are those whose index parity matches mask_type.
        parity = 0 if mask_type == 'even' else 1
        self.register_buffer('mask', torch.arange(dim) % 2 == parity)

        n_passive = int(self.mask.sum().item())
        n_active = dim - n_passive

        self.net = nn.Sequential(
            nn.Linear(n_passive + d_context, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 2 * n_active),
        )
        # Identity initialization: s = t = 0 at the start of training.
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def _clamp_s(self, s_raw):
        """Smoothly bound the log-scale to (-s_clamp, s_clamp) with
        nonzero gradient everywhere."""
        return self.s_clamp * torch.tanh(s_raw / self.s_clamp)

    def _scale_and_shift(self, passive, context):
        """Run the conditioner; return the clamped (s, t) pair."""
        raw = self.net(torch.cat([passive, context], dim=-1))
        s_raw, t = raw.chunk(2, dim=-1)
        return self._clamp_s(s_raw), t

    def forward(self, z, context):
        """Base -> parameter space. Returns (theta, log|det J|) per sample."""
        passive, active = z[:, self.mask], z[:, ~self.mask]
        s, t = self._scale_and_shift(passive, context)

        theta = torch.empty_like(z)
        theta[:, self.mask] = passive
        theta[:, ~self.mask] = active * torch.exp(s) + t
        return theta, s.sum(dim=-1)

    def inverse(self, theta, context):
        """Parameter -> base space. Returns (z, log|det J|) per sample."""
        passive, active = theta[:, self.mask], theta[:, ~self.mask]
        s, t = self._scale_and_shift(passive, context)

        z = torch.empty_like(theta)
        z[:, self.mask] = passive
        z[:, ~self.mask] = (active - t) * torch.exp(-s)
        return z, -s.sum(dim=-1)
|
|
|
|
| |
| |
| |
|
|
| MIN_BIN_FRACTION = 1e-2 |
| MIN_DERIVATIVE = 1e-2 |
|
|
|
|
| def _prepare_spline_params(widths, heights, derivatives, tail_bound): |
| """Shared preprocessing for forward and inverse spline transforms. |
| |
| Enforces minimum bin width/height (following nflows convention) to prevent |
| degenerate near-step-function splines that break invertibility. |
| """ |
| K = widths.shape[-1] |
| total = 2 * tail_bound |
|
|
| widths = F.softmax(widths, dim=-1) |
| widths = MIN_BIN_FRACTION + (1 - K * MIN_BIN_FRACTION) * widths |
| widths = widths * total |
|
|
| heights = F.softmax(heights, dim=-1) |
| heights = MIN_BIN_FRACTION + (1 - K * MIN_BIN_FRACTION) * heights |
| heights = heights * total |
|
|
| derivatives = F.softplus(derivatives) + MIN_DERIVATIVE |
|
|
| cumwidths = torch.cumsum(widths, dim=-1) |
| cumwidths = F.pad(cumwidths, (1, 0), value=0.0) - tail_bound |
| cumheights = torch.cumsum(heights, dim=-1) |
| cumheights = F.pad(cumheights, (1, 0), value=0.0) - tail_bound |
|
|
| return widths, heights, derivatives, cumwidths, cumheights |
|
|
|
|
def rational_quadratic_spline_forward(x, widths, heights, derivatives, tail_bound=5.0):
    """
    Apply monotonic rational-quadratic spline transform (forward direction).

    Piecewise map of Durkan et al. 2019 ("Neural Spline Flows"): within each
    bin, y is a ratio of quadratics fixed by the bin's knots and knot
    derivatives. Uses the identity transform (zero log-det) outside
    [-tail_bound, tail_bound].

    Args:
        x: [B, D] input values
        widths: [B, D, K] bin widths (pre-softmax)
        heights: [B, D, K] bin heights (pre-softmax)
        derivatives: [B, D, K+1] knot derivatives (pre-softplus)
        tail_bound: half-width of the spline domain
    Returns:
        y: [B, D] transformed values
        log_det: [B, D] log absolute derivative per dimension
    """
    K = widths.shape[-1]
    widths, heights, derivatives, cumwidths, cumheights = \
        _prepare_spline_params(widths, heights, derivatives, tail_bound)

    inside = (x >= -tail_bound) & (x <= tail_bound)

    # Start from the identity map; only in-domain entries are overwritten.
    y = x.clone()
    log_det = torch.zeros_like(x)

    if not inside.any():
        return y, log_det

    x_in = x[inside]

    # Locate each in-domain value's bin via binary search over the interior
    # knots (cumwidths minus the leading -tail_bound knot); clamp keeps
    # boundary values in a valid bin.
    cw_in = cumwidths[inside]
    bin_idx = torch.searchsorted(cw_in[:, 1:].contiguous(), x_in.unsqueeze(-1)).squeeze(-1)
    bin_idx = bin_idx.clamp(0, K - 1)

    # Gather per-element bin parameters: width w_k, height h_k, left/right
    # knot derivatives d_k/d_k1, and the bin's lower x/y knot coordinates.
    idx = bin_idx.unsqueeze(-1)
    w_k = widths[inside].gather(-1, idx).squeeze(-1)
    h_k = heights[inside].gather(-1, idx).squeeze(-1)
    d_k = derivatives[inside].gather(-1, idx).squeeze(-1)
    d_k1 = derivatives[inside].gather(-1, idx + 1).squeeze(-1)
    cw_k = cw_in.gather(-1, idx).squeeze(-1)
    ch_k = cumheights[inside].gather(-1, idx).squeeze(-1)

    # Normalized position within the bin, kept strictly inside (0, 1) for
    # numerical stability of the rational-quadratic ratio.
    xi = (x_in - cw_k) / w_k
    xi = xi.clamp(1e-6, 1.0 - 1e-6)

    # Bin-wise transform (Durkan et al., eq. 4), with bin slope s_k = h/w.
    s_k = h_k / w_k
    numer = h_k * (s_k * xi.pow(2) + d_k * xi * (1 - xi))
    denom = s_k + (d_k + d_k1 - 2 * s_k) * xi * (1 - xi)
    y[inside] = ch_k + numer / denom

    # Log-derivative (Durkan et al., eq. 5): dy/dx = s^2 * (...) / denom^2,
    # computed in log space with clamps guarding log(0).
    deriv_numer = s_k.pow(2) * (d_k1 * xi.pow(2) + 2 * s_k * xi * (1 - xi) + d_k * (1 - xi).pow(2))
    log_det[inside] = torch.log(deriv_numer.clamp(min=1e-8)) - 2 * torch.log(denom.abs().clamp(min=1e-8))

    # Safeguard against extreme per-dimension Jacobians destabilizing NLL.
    log_det = log_det.clamp(-20.0, 20.0)
    return y, log_det
|
|
|
|
def rational_quadratic_spline_inverse(y, widths, heights, derivatives, tail_bound=5.0):
    """
    Apply inverse rational-quadratic spline transform.

    Inverts the forward map analytically by solving a quadratic in the
    normalized bin position xi (Durkan et al. 2019, eqs. 6-8). Uses the
    identity transform outside [-tail_bound, tail_bound].

    Args:
        y: [B, D] values in the transformed space
        widths/heights/derivatives: raw spline params, as in the forward pass
    Returns:
        x: [B, D] pre-images
        log_det: [B, D] log |dx/dy| per dimension (negated forward log-det)
    """
    K = widths.shape[-1]
    widths, heights, derivatives, cumwidths, cumheights = \
        _prepare_spline_params(widths, heights, derivatives, tail_bound)

    inside = (y >= -tail_bound) & (y <= tail_bound)

    # Identity by default; only in-domain entries are overwritten.
    x = y.clone()
    log_det = torch.zeros_like(y)

    if not inside.any():
        return x, log_det

    y_in = y[inside]

    # Bin lookup happens in the *output* (cumulative-height) coordinates.
    ch_in = cumheights[inside]
    bin_idx = torch.searchsorted(ch_in[:, 1:].contiguous(), y_in.unsqueeze(-1)).squeeze(-1)
    bin_idx = bin_idx.clamp(0, K - 1)

    # Gather per-element bin parameters (see forward transform).
    idx = bin_idx.unsqueeze(-1)
    w_k = widths[inside].gather(-1, idx).squeeze(-1)
    h_k = heights[inside].gather(-1, idx).squeeze(-1)
    d_k = derivatives[inside].gather(-1, idx).squeeze(-1)
    d_k1 = derivatives[inside].gather(-1, idx + 1).squeeze(-1)
    cw_k = cumwidths[inside].gather(-1, idx).squeeze(-1)
    ch_k = ch_in.gather(-1, idx).squeeze(-1)

    s_k = h_k / w_k
    y_rel = y_in - ch_k

    # Coefficients of a*xi^2 + b*xi + c = 0 whose root in (0, 1) is the
    # normalized bin position (Durkan et al., eqs. 6-8).
    a = h_k * (s_k - d_k) + y_rel * (d_k + d_k1 - 2 * s_k)
    b = h_k * d_k - y_rel * (d_k + d_k1 - 2 * s_k)
    c = -s_k * y_rel

    # Numerically stable quadratic root xi = 2c / (-b - sqrt(b^2 - 4ac));
    # clamps guard the sqrt argument and the (negative) denominator.
    discriminant = b.pow(2) - 4 * a * c
    sqrt_disc = torch.sqrt(discriminant.clamp(min=1e-8))
    xi = (2 * c) / (-b - sqrt_disc).clamp(max=-1e-8)
    xi = xi.clamp(1e-6, 1.0 - 1e-6)

    x[inside] = cw_k + w_k * xi

    # Forward derivative dy/dx at the recovered xi; negated on return so
    # the result is log |dx/dy|.
    deriv_numer = s_k.pow(2) * (d_k1 * xi.pow(2) + 2 * s_k * xi * (1 - xi) + d_k * (1 - xi).pow(2))
    denom = s_k + (d_k + d_k1 - 2 * s_k) * xi * (1 - xi)
    log_det[inside] = torch.log(deriv_numer.clamp(min=1e-8)) - 2 * torch.log(denom.abs().clamp(min=1e-8))

    log_det = log_det.clamp(-20.0, 20.0)
    return x, -log_det
|
|
|
|
class ConditionalSplineCoupling(nn.Module):
    """
    Conditional Neural Spline coupling layer (Durkan et al. 2019).

    Same interface as ConditionalAffineCoupling, but each active dimension
    is transformed by a monotone rational-quadratic spline whose K bin
    widths, K bin heights and K+1 knot derivatives are predicted from
    (passive half, context). Much more expressive per layer than an affine
    transform.

    Args:
        dim: parameter dimension
        d_context: context vector dimension
        hidden_dim: hidden layer size in conditioner network
        mask_type: 'even' or 'odd'
        n_bins: number of spline bins (K)
        tail_bound: spline domain [-B, B]; identity outside
    """
    def __init__(self, dim, d_context, hidden_dim=128, mask_type='even',
                 n_bins=8, tail_bound=5.0):
        super().__init__()
        self.dim = dim
        self.n_bins = n_bins
        self.tail_bound = tail_bound

        # Passive dimensions are those whose index parity matches mask_type.
        parity = 0 if mask_type == 'even' else 1
        self.register_buffer('mask', torch.arange(dim) % 2 == parity)

        n_passive = int(self.mask.sum().item())
        self.n_b = dim - n_passive  # number of active (transformed) dims

        # Per active dim: K widths + K heights + (K+1) derivatives.
        self.net = nn.Sequential(
            nn.Linear(n_passive + d_context, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, self.n_b * (3 * n_bins + 1)),
        )
        # Zero-init the last layer: raw spline params start at 0 (uniform
        # bins, constant positive derivatives) for stable early training.
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def _get_spline_params(self, passive, context):
        """Run the conditioner; split output into (widths, heights, derivatives)."""
        raw = self.net(torch.cat([passive, context], dim=-1))
        raw = raw.view(-1, self.n_b, 3 * self.n_bins + 1)
        K = self.n_bins
        # Clamp raw (pre-softmax / pre-softplus) outputs so one bad batch
        # cannot push the spline into a degenerate configuration.
        widths = raw[..., :K].clamp(-5.0, 5.0)
        heights = raw[..., K:2 * K].clamp(-5.0, 5.0)
        derivatives = raw[..., 2 * K:].clamp(-5.0, 5.0)
        return widths, heights, derivatives

    def forward(self, z, context):
        """Base -> parameter space. Returns (theta, log|det J|) per sample."""
        passive, active = z[:, self.mask], z[:, ~self.mask]

        w, h, d = self._get_spline_params(passive, context)
        transformed, ld_per_dim = rational_quadratic_spline_forward(
            active, w, h, d, self.tail_bound)

        theta = torch.empty_like(z)
        theta[:, self.mask] = passive
        theta[:, ~self.mask] = transformed
        return theta, ld_per_dim.sum(dim=-1)

    def inverse(self, theta, context):
        """Parameter -> base space. Returns (z, log|det J|) per sample."""
        passive, active = theta[:, self.mask], theta[:, ~self.mask]

        w, h, d = self._get_spline_params(passive, context)
        restored, ld_per_dim = rational_quadratic_spline_inverse(
            active, w, h, d, self.tail_bound)

        z = torch.empty_like(theta)
        z[:, self.mask] = passive
        z[:, ~self.mask] = restored
        return z, ld_per_dim.sum(dim=-1)
|
|
|
|
| |
| |
| |
|
|
class ConditionalFlow(nn.Module):
    """
    Conditional normalizing flow for p(theta | x).

    Maps between base distribution z ~ N(0, I) and parameter space theta,
    conditioned on the observed signal x.

    Supports both affine and spline coupling layers. The flow stack
    alternates ActNorm and coupling layers, with coupling masks alternating
    'even'/'odd' so every theta dimension is transformed.

    Directions:
        forward_flow: z -> theta_norm (sampling)
        inverse_flow: theta_norm -> z (training / NLL)
    theta is standardized with (theta_mean, theta_std) before entering the
    flow; set these via `set_theta_stats`.
    """
    def __init__(
        self,
        theta_dim=3,
        d_context=128,
        n_coupling_layers=8,
        hidden_dim=128,
        d_model=128,
        coupling_type='spline',
        n_bins=8,
        tail_bound=5.0,
    ):
        super().__init__()
        self.theta_dim = theta_dim
        self.d_context = d_context
        self.coupling_type = coupling_type
        self.tail_bound = tail_bound

        # Signal encoder producing the conditioning context vector.
        self.encoder = SignalEncoder(
            in_channels=3, d_model=d_model, d_context=d_context
        )

        # Flow stack: (ActNorm, coupling) pairs with alternating masks.
        self.flows = nn.ModuleList()
        for i in range(n_coupling_layers):
            mask_type = 'even' if i % 2 == 0 else 'odd'
            self.flows.append(ActNorm(theta_dim))
            if coupling_type == 'spline':
                self.flows.append(
                    ConditionalSplineCoupling(
                        dim=theta_dim,
                        d_context=d_context,
                        hidden_dim=hidden_dim,
                        mask_type=mask_type,
                        n_bins=n_bins,
                        tail_bound=tail_bound,
                    )
                )
            else:
                self.flows.append(
                    ConditionalAffineCoupling(
                        dim=theta_dim,
                        d_context=d_context,
                        hidden_dim=hidden_dim,
                        mask_type=mask_type,
                    )
                )

        # Standardization stats as buffers so they persist in checkpoints
        # and move with the module across devices.
        self.register_buffer('theta_mean', torch.zeros(theta_dim))
        self.register_buffer('theta_std', torch.ones(theta_dim))

    def set_theta_stats(self, mean, std):
        """Set theta standardization statistics (e.g. from the training set)."""
        self.theta_mean.copy_(torch.as_tensor(mean, dtype=torch.float32))
        self.theta_std.copy_(torch.as_tensor(std, dtype=torch.float32))

    def normalize_theta(self, theta):
        """Standardize theta into flow space."""
        return (theta - self.theta_mean) / self.theta_std

    def denormalize_theta(self, theta_norm):
        """Map standardized theta back to its original scale."""
        return theta_norm * self.theta_std + self.theta_mean

    def encode_signal(self, x, mask=None):
        """Encode signal x [B, 3, T] into a context vector [B, d_context]."""
        return self.encoder(x, mask=mask)

    def forward_flow(self, z, context):
        """Run z -> theta_norm through all layers; returns (theta_norm, total log|det|)."""
        total_log_det = torch.zeros(z.shape[0], device=z.device)
        h = z
        for layer in self.flows:
            if isinstance(layer, ActNorm):
                # ActNorm is unconditional and returns its own log-det.
                h, ld = layer(h)
                total_log_det += ld
            else:
                h, ld = layer(h, context)
                total_log_det += ld
        return h, total_log_det

    def inverse_flow(self, theta_norm, context):
        """Run theta_norm -> z through all layers in reverse order;
        returns (z, total log|det|)."""
        total_log_det = torch.zeros(theta_norm.shape[0], device=theta_norm.device)
        h = theta_norm
        for layer in reversed(self.flows):
            if isinstance(layer, ActNorm):
                # ActNorm.inverse returns only x, so the inverse Jacobian
                # term (-sum log_scale) is accumulated here explicitly.
                h = layer.inverse(h)
                total_log_det -= layer.log_scale.sum()
            else:
                h, ld = layer.inverse(h, context)
                total_log_det += ld
        return h, total_log_det

    def log_prob(self, theta, x, mask=None):
        """
        Per-sample log-density log q(theta | x).

        NOTE: theta_norm is clamped to [-tail_bound, tail_bound] and the
        result is clamped to [-50, 50] as training stabilizers, so values
        at the extremes are not exact densities.
        """
        context = self.encode_signal(x, mask=mask)
        theta_norm = self.normalize_theta(theta)
        theta_norm = theta_norm.clamp(-self.tail_bound, self.tail_bound)
        z, log_det = self.inverse_flow(theta_norm, context)
        # Standard-normal base density.
        log_pz = -0.5 * (z ** 2 + math.log(2 * math.pi)).sum(dim=-1)
        # Jacobian of the fixed standardization theta -> theta_norm.
        log_det_norm = -torch.log(self.theta_std).sum()
        log_p = log_pz + log_det + log_det_norm
        return log_p.clamp(min=-50.0, max=50.0)

    @torch.no_grad()
    def sample(self, x, mask=None, n_samples=100):
        """Draw n_samples posterior samples per batch element.

        Returns [B, n_samples, theta_dim] in the original theta scale.
        """
        B = x.shape[0]
        context = self.encode_signal(x, mask=mask)
        # Repeat each context n_samples times so all draws share one pass.
        context_rep = context.unsqueeze(1).expand(-1, n_samples, -1)
        context_rep = context_rep.reshape(B * n_samples, -1)
        z = torch.randn(B * n_samples, self.theta_dim, device=x.device)
        theta_norm, _ = self.forward_flow(z, context_rep)
        theta = self.denormalize_theta(theta_norm)
        return theta.reshape(B, n_samples, self.theta_dim)

    @torch.no_grad()
    def posterior_stats(self, x, mask=None, n_samples=1000):
        """Monte-Carlo posterior summaries: mean/std/median, 5%/95%
        quantiles, plus the raw samples."""
        samples = self.sample(x, mask=mask, n_samples=n_samples)
        return {
            'mean': samples.mean(dim=1),
            'std': samples.std(dim=1),
            'median': samples.median(dim=1).values,
            'q05': samples.quantile(0.05, dim=1),
            'q95': samples.quantile(0.95, dim=1),
            'samples': samples,
        }
|
|
|
|
| |
| ConditionalRealNVP = ConditionalFlow |
|
|
|
|
| |
| |
| |
|
|
# Default 3-parameter theta names; matches MECHANISM_PARAMS['BV'] below.
THETA_NAMES = ['log10(K0)', 'alpha', 'log10(dB)']


# Per-mechanism theta layout: human-readable parameter names and the
# dimension of theta the flow must model for each mechanism.
# NOTE(review): the physical meaning of each entry (K0, dB, Gamma_sat, ...)
# is defined by the simulator — confirm against its documentation.
MECHANISM_PARAMS = {
    'Nernst': {
        'names': ['E0_offset', 'log10(dA)', 'log10(dB)'],
        'dim': 3,
    },
    'BV': {
        'names': ['log10(K0)', 'alpha', 'log10(dB)'],
        'dim': 3,
    },
    'MHC': {
        'names': ['log10(K0)', 'log10(reorg_e)', 'log10(dB)'],
        'dim': 3,
    },
    'Ads': {
        'names': ['log10(K0)', 'alpha', 'log10(Gamma_sat)'],
        'dim': 3,
    },
    'EC': {
        'names': ['log10(K0)', 'alpha', 'log10(kc)', 'log10(dB)'],
        'dim': 4,
    },
    'LH': {
        'names': ['log10(K0)', 'alpha', 'log10(KA_eq)', 'log10(KB_eq)', 'log10(dB)'],
        'dim': 5,
    },
}


# Canonical ordering of the supported mechanisms; keys of MECHANISM_PARAMS.
MECHANISM_LIST = ['Nernst', 'BV', 'MHC', 'Ads', 'EC', 'LH']
|
|
|
|
def count_parameters(model):
    """Return the total number of trainable parameters in `model`."""
    total = 0
    for p in model.parameters():
        if p.requires_grad:
            total += p.numel()
    return total
|
|
|
|
| if __name__ == "__main__": |
| B, T = 4, 800 |
| theta_dim = 3 |
|
|
| x = torch.randn(B, 3, T) |
| theta = torch.randn(B, theta_dim) |
| mask = torch.ones(B, T, dtype=torch.bool) |
|
|
| for coupling_type in ['affine', 'spline']: |
| print(f"\n{'=' * 50}") |
| print(f"Testing ConditionalFlow (coupling={coupling_type})") |
| print(f"{'=' * 50}") |
|
|
| model = ConditionalFlow( |
| theta_dim=theta_dim, |
| d_context=128, |
| n_coupling_layers=8, |
| hidden_dim=128, |
| d_model=128, |
| coupling_type=coupling_type, |
| ) |
|
|
| print(f"Parameters: {count_parameters(model):,}") |
|
|
| log_q = model.log_prob(theta, x, mask=mask) |
| print(f"log_prob shape: {log_q.shape}, values: {log_q}") |
|
|
| samples = model.sample(x, mask=mask, n_samples=100) |
| print(f"Samples shape: {samples.shape}") |
| print(f"Sample mean: {samples.mean(dim=1)}") |
| print(f"Sample std: {samples.std(dim=1)}") |
|
|