# ecflow/flow_model.py
# Author: Bing Yan
# Initial ECFlow deployment (commit d6b782e)
"""
Conditional Normalizing Flow for Electrochemical Parameter Inference.
Models p(theta | x) where:
theta = simulator parameters (variable dimension per mechanism)
x = (E(t), j(t), t) -- observed electrochemical signal [3, T]
Supports two coupling layer types:
- Affine: simple scale+shift (original RealNVP)
- Spline: rational-quadratic spline (Neural Spline Flows, Durkan et al. 2019)
Architecture:
1. SignalEncoder: 1D CNN + global pooling -> fixed-size context vector
2. ConditionalFlow: Coupling layers (affine or spline) conditioned on context
Training objective: Negative log-likelihood (NLL)
L = -E[log q_phi(theta | x)]
= -E[log p_z(f^{-1}(theta; x)) + log |det df^{-1}/d_theta|]
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# =============================================================================
# Signal Encoder: x -> context vector
# =============================================================================
class SignalEncoder(nn.Module):
    """
    Map a (possibly padded) electrochemical waveform to a fixed-size
    context vector.

    Pipeline: three Conv1d + GELU stages, then (masked) global average
    pooling over the time axis, then a two-layer MLP projection.

    Input:  [B, 3, T] — channels are (potential, flux, time)
    Output: [B, d_context]
    """
    def __init__(self, in_channels=3, d_model=128, d_context=128):
        super().__init__()
        half = d_model // 2
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels, half, kernel_size=7, padding=3),
            nn.GELU(),
            nn.Conv1d(half, d_model, kernel_size=5, padding=2),
            nn.GELU(),
            nn.Conv1d(d_model, d_model, kernel_size=3, padding=1),
            nn.GELU(),
        )
        self.pool_proj = nn.Sequential(
            nn.Linear(d_model, d_context),
            nn.GELU(),
            nn.Linear(d_context, d_context),
        )

    def forward(self, x, mask=None):
        """Encode x [B, 3, T] -> context [B, d_context].

        mask: optional [B, T] boolean mask of valid time steps; when given,
        padded positions are excluded from the average pool.
        """
        features = self.conv(x)
        if mask is None:
            pooled = features.mean(dim=-1)
        else:
            # Masked mean: zero padded positions, divide by the true length
            # (clamped to avoid division by zero on fully-masked rows).
            valid = mask.unsqueeze(1).float()
            pooled = (features * valid).sum(dim=-1) / valid.sum(dim=-1).clamp(min=1)
        return self.pool_proj(pooled)
# =============================================================================
# Normalizing Flow Components
# =============================================================================
class ActNorm(nn.Module):
    """Activation normalization (from Glow) with data-dependent init.

    Per-dimension affine map:
        forward: y = (x + bias) * exp(log_scale)
        inverse: x = y * exp(-log_scale) - bias
    Parameters are set from the first batch seen so the output of that
    first call is approximately zero-mean / unit-variance.

    Bug fix: the original only ran the data-dependent init inside
    `forward`, but this flow trains through `inverse` (the log_prob
    direction), so the init never fired during training and every ActNorm
    stayed at its identity initialization. `inverse` now also performs a
    data-dependent init on first use, choosing parameters so that its own
    output is normalized. The external interface is unchanged: `forward`
    returns (y, log_det); `inverse` returns only x.
    """
    def __init__(self, dim):
        super().__init__()
        # Identity transform until the first batch triggers initialization.
        self.log_scale = nn.Parameter(torch.zeros(dim))
        self.bias = nn.Parameter(torch.zeros(dim))
        # Buffer (not Parameter) so the flag survives state_dict save/load.
        self.register_buffer('_initialized', torch.tensor(False))

    @property
    def initialized(self):
        return bool(self._initialized.item())

    @initialized.setter
    def initialized(self, value):
        self._initialized.fill_(value)

    def initialize(self, x):
        """Data-dependent init from a batch of *forward* inputs x.

        Sets bias/log_scale so that forward(x) has zero mean and unit
        variance per dimension.
        """
        with torch.no_grad():
            self.bias.data = -x.mean(dim=0)
            if x.shape[0] > 1:
                # Clamp std to avoid huge scales on near-constant dims.
                std = x.std(dim=0).clamp(min=0.1)
                self.log_scale.data = -torch.log(std)
            else:
                # A single sample has no spread; keep unit scale.
                self.log_scale.data.zero_()
        self.initialized = True

    def _initialize_from_inverse(self, y):
        """Data-dependent init from a batch of *inverse* inputs y.

        Chooses parameters so that inverse(y) = (y - mean) / std, i.e. the
        inverse output of the first (training) batch is normalized.
        """
        with torch.no_grad():
            if y.shape[0] > 1:
                std = y.std(dim=0).clamp(min=0.1)
                self.log_scale.data = torch.log(std)
            else:
                self.log_scale.data.zero_()
            # x = y * exp(-log_scale) - bias  =>  bias = mean(y) / std(y)
            self.bias.data = y.mean(dim=0) * torch.exp(-self.log_scale.data)
        self.initialized = True

    def forward(self, x):
        if not self.initialized:
            self.initialize(x)
        y = (x + self.bias) * torch.exp(self.log_scale)
        # log|det J| is the same for every sample in the batch.
        log_det = self.log_scale.sum()
        return y, log_det

    def inverse(self, y):
        if not self.initialized:
            self._initialize_from_inverse(y)
        x = y * torch.exp(-self.log_scale) - self.bias
        return x
class ConditionalAffineCoupling(nn.Module):
    """
    Conditional affine coupling layer (RealNVP-style).

    The vector is split into a pass-through part A (selected by `mask`)
    and a transformed part B. A conditioner MLP maps [A, context] to an
    elementwise log-scale s and shift t:

        forward (z -> theta):  B' = B * exp(s) + t
        inverse (theta -> z):  B  = (B' - t) * exp(-s)

    The log-scale s is soft-clamped to [-s_clamp, s_clamp]. With s_clamp=2.0
    each layer can scale by up to exp(2)≈7.4x per dimension, giving the flow
    enough dynamic range to produce both narrow (identifiable) and wide
    (non-identifiable) posteriors. The conditioner's last layer is
    zero-initialized so the layer starts as an exact identity.
    """
    def __init__(self, dim, d_context, hidden_dim=128, mask_type='even', s_clamp=2.0):
        super().__init__()
        self.dim = dim
        self.s_clamp = s_clamp
        parity = 0 if mask_type == 'even' else 1
        self.register_buffer('mask', torch.arange(dim) % 2 == parity)
        n_pass = int(self.mask.sum().item())
        n_trans = dim - n_pass
        self.net = nn.Sequential(
            nn.Linear(n_pass + d_context, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 2 * n_trans),
        )
        # Identity at init: the conditioner outputs s = t = 0.
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def _clamp_s(self, s_raw):
        """Soft-clamp log-scale using 2*tanh(s/2) for smooth gradients everywhere."""
        return self.s_clamp * torch.tanh(s_raw / self.s_clamp)

    def _scale_shift(self, passthrough, context):
        """Run the conditioner; return clamped log-scale s and shift t."""
        raw = self.net(torch.cat([passthrough, context], dim=-1))
        s_raw, shift = raw.chunk(2, dim=-1)
        return self._clamp_s(s_raw), shift

    def forward(self, z, context):
        """z -> theta; returns (theta, log|det J|) with log-det summed over dims."""
        part_a, part_b = z[:, self.mask], z[:, ~self.mask]
        s, t = self._scale_shift(part_a, context)
        out = torch.empty_like(z)
        out[:, self.mask] = part_a
        out[:, ~self.mask] = part_b * torch.exp(s) + t
        return out, s.sum(dim=-1)

    def inverse(self, theta, context):
        """theta -> z; returns (z, log|det J^{-1}|) with log-det summed over dims."""
        part_a, part_b = theta[:, self.mask], theta[:, ~self.mask]
        s, t = self._scale_shift(part_a, context)
        out = torch.empty_like(theta)
        out[:, self.mask] = part_a
        out[:, ~self.mask] = (part_b - t) * torch.exp(-s)
        return out, -s.sum(dim=-1)
# =============================================================================
# Rational-Quadratic Spline Transform (Durkan et al. 2019)
# =============================================================================
MIN_BIN_FRACTION = 1e-2  # each bin gets at least 1% of the total width/height
MIN_DERIVATIVE = 1e-2
def _prepare_spline_params(widths, heights, derivatives, tail_bound):
    """Shared preprocessing for forward and inverse spline transforms.

    Raw widths/heights are softmaxed into bin fractions, floored at
    MIN_BIN_FRACTION (following nflows convention) so no bin collapses to
    a near-step function that would break invertibility, then scaled to
    cover [-tail_bound, tail_bound]. Raw derivatives pass through softplus
    plus a MIN_DERIVATIVE floor so every knot slope stays strictly
    positive (monotonicity).

    Returns (widths, heights, derivatives, cumwidths, cumheights); the
    cumulative tensors carry K+1 knot positions starting at -tail_bound.
    """
    n_bins = widths.shape[-1]
    span = 2 * tail_bound

    def bin_sizes(raw):
        # Softmax to fractions, floor each bin, rescale to the full span.
        fractions = F.softmax(raw, dim=-1)
        fractions = MIN_BIN_FRACTION + (1 - n_bins * MIN_BIN_FRACTION) * fractions
        return fractions * span

    def knots(sizes):
        # Cumulative sum with a leading zero knot, shifted to start at -tail_bound.
        return F.pad(torch.cumsum(sizes, dim=-1), (1, 0), value=0.0) - tail_bound

    widths = bin_sizes(widths)
    heights = bin_sizes(heights)
    derivatives = F.softplus(derivatives) + MIN_DERIVATIVE
    return widths, heights, derivatives, knots(widths), knots(heights)
def rational_quadratic_spline_forward(x, widths, heights, derivatives, tail_bound=5.0):
    """
    Apply monotonic rational-quadratic spline transform (forward direction).
    Uses identity transform outside [-tail_bound, tail_bound].
    Within each bin the map is the monotone rational-quadratic of
    Durkan et al. 2019 ("Neural Spline Flows", Eq. 4), which matches the
    bin endpoint values and the knot derivatives.
    Args:
        x: [B, D] input values
        widths: [B, D, K] bin widths (pre-softmax)
        heights: [B, D, K] bin heights (pre-softmax)
        derivatives: [B, D, K+1] knot derivatives (pre-softplus)
        tail_bound: spline acts on [-tail_bound, tail_bound]
    Returns:
        y: [B, D] transformed values
        log_det: [B, D] log absolute derivative per dimension
    """
    K = widths.shape[-1]
    # Normalize raw params into bin sizes, positive knot slopes, knot positions.
    widths, heights, derivatives, cumwidths, cumheights = \
        _prepare_spline_params(widths, heights, derivatives, tail_bound)
    inside = (x >= -tail_bound) & (x <= tail_bound)
    # Default: identity (for tails); log_det stays 0 there.
    y = x.clone()
    log_det = torch.zeros_like(x)
    if not inside.any():
        return y, log_det
    x_in = x[inside]
    # Bin lookup on the interior cumwidths
    # Flatten the relevant cumwidths for searchsorted
    cw_in = cumwidths[inside] # [N_inside, K+1]
    # searchsorted over the K right-edge knots gives each element's bin index.
    bin_idx = torch..searchsorted(cw_in[:, 1:].contiguous(), x_in.unsqueeze(-1)).squeeze(-1)
    bin_idx = bin_idx.clamp(0, K - 1)
    idx = bin_idx.unsqueeze(-1)
    # Per-element bin geometry: width/height of the bin, derivatives at its
    # left (d_k) and right (d_k1) knots, and the bin's lower-left corner.
    w_k = widths[inside].gather(-1, idx).squeeze(-1)
    h_k = heights[inside].gather(-1, idx).squeeze(-1)
    d_k = derivatives[inside].gather(-1, idx).squeeze(-1)
    d_k1 = derivatives[inside].gather(-1, idx + 1).squeeze(-1)
    cw_k = cw_in.gather(-1, idx).squeeze(-1)
    ch_k = cumheights[inside].gather(-1, idx).squeeze(-1)
    # xi in (0, 1): relative position within the bin (clamped for numerics).
    xi = (x_in - cw_k) / w_k
    xi = xi.clamp(1e-6, 1.0 - 1e-6)
    # s_k: average slope of the bin.
    s_k = h_k / w_k
    numer = h_k * (s_k * xi.pow(2) + d_k * xi * (1 - xi))
    denom = s_k + (d_k + d_k1 - 2 * s_k) * xi * (1 - xi)
    y[inside] = ch_k + numer / denom
    # Analytic derivative of the RQ map (Durkan et al., Eq. 5).
    deriv_numer = s_k.pow(2) * (d_k1 * xi.pow(2) + 2 * s_k * xi * (1 - xi) + d_k * (1 - xi).pow(2))
    log_det[inside] = torch.log(deriv_numer.clamp(min=1e-8)) - 2 * torch.log(denom.abs().clamp(min=1e-8))
    # Guard against extreme Jacobians destabilizing training.
    log_det = log_det.clamp(-20.0, 20.0)
    return y, log_det
def rational_quadratic_spline_inverse(y, widths, heights, derivatives, tail_bound=5.0):
    """
    Apply inverse rational-quadratic spline transform.
    Uses identity transform outside [-tail_bound, tail_bound].
    Inverts the per-bin rational-quadratic of Durkan et al. 2019 by
    solving the quadratic a*xi^2 + b*xi + c = 0 for the relative bin
    position xi, using the numerically stable root 2c / (-b - sqrt(disc)).
    Args:
        y: [B, D] values to invert
        widths/heights: [B, D, K] raw bin parameters (pre-softmax)
        derivatives: [B, D, K+1] raw knot derivatives (pre-softplus)
    Returns:
        x: [B, D] pre-images, and per-dimension log|det| of the INVERSE map
        (i.e. the negated forward log-derivative).
    """
    K = widths.shape[-1]
    widths, heights, derivatives, cumwidths, cumheights = \
        _prepare_spline_params(widths, heights, derivatives, tail_bound)
    inside = (y >= -tail_bound) & (y <= tail_bound)
    # Identity in the tails; log_det stays 0 there.
    x = y.clone()
    log_det = torch.zeros_like(y)
    if not inside.any():
        return x, log_det
    y_in = y[inside]
    ch_in = cumheights[inside]
    # Bin lookup is on cumulative HEIGHTS here (we invert the y-axis).
    bin_idx = torch.searchsorted(ch_in[:, 1:].contiguous(), y_in.unsqueeze(-1)).squeeze(-1)
    bin_idx = bin_idx.clamp(0, K - 1)
    idx = bin_idx.unsqueeze(-1)
    # Per-element bin geometry (see forward transform for naming).
    w_k = widths[inside].gather(-1, idx).squeeze(-1)
    h_k = heights[inside].gather(-1, idx).squeeze(-1)
    d_k = derivatives[inside].gather(-1, idx).squeeze(-1)
    d_k1 = derivatives[inside].gather(-1, idx + 1).squeeze(-1)
    cw_k = cumwidths[inside].gather(-1, idx).squeeze(-1)
    ch_k = ch_in.gather(-1, idx).squeeze(-1)
    s_k = h_k / w_k
    y_rel = y_in - ch_k
    # Quadratic coefficients for xi (nflows formulation).
    a = h_k * (s_k - d_k) + y_rel * (d_k + d_k1 - 2 * s_k)
    b = h_k * d_k - y_rel * (d_k + d_k1 - 2 * s_k)
    c = -s_k * y_rel
    discriminant = b.pow(2) - 4 * a * c
    sqrt_disc = torch.sqrt(discriminant.clamp(min=1e-8))
    # Stable root: denominator clamped strictly negative to avoid div-by-zero.
    xi = (2 * c) / (-b - sqrt_disc).clamp(max=-1e-8)
    xi = xi.clamp(1e-6, 1.0 - 1e-6)
    x[inside] = cw_k + w_k * xi
    # Forward derivative at the recovered xi; inverse log-det is its negation.
    deriv_numer = s_k.pow(2) * (d_k1 * xi.pow(2) + 2 * s_k * xi * (1 - xi) + d_k * (1 - xi).pow(2))
    denom = s_k + (d_k + d_k1 - 2 * s_k) * xi * (1 - xi)
    log_det[inside] = torch.log(deriv_numer.clamp(min=1e-8)) - 2 * torch.log(denom.abs().clamp(min=1e-8))
    log_det = log_det.clamp(-20.0, 20.0)
    return x, -log_det
class ConditionalSplineCoupling(nn.Module):
    """
    Conditional Neural Spline coupling layer (Durkan et al. 2019).

    Drop-in alternative to ConditionalAffineCoupling: the pass-through
    half (selected by `mask`) and the context vector parameterize a
    monotonic rational-quadratic spline applied to the other half.
    Much more expressive per layer than an affine transform.

    Args:
        dim: parameter dimension
        d_context: context vector dimension
        hidden_dim: hidden layer size in conditioner network
        mask_type: 'even' or 'odd'
        n_bins: number of spline bins (K)
        tail_bound: spline domain [-B, B]
    """
    def __init__(self, dim, d_context, hidden_dim=128, mask_type='even',
                 n_bins=8, tail_bound=5.0):
        super().__init__()
        self.dim = dim
        self.n_bins = n_bins
        self.tail_bound = tail_bound
        parity = 0 if mask_type == 'even' else 1
        self.register_buffer('mask', torch.arange(dim) % 2 == parity)
        n_pass = int(self.mask.sum().item())
        self.n_b = dim - n_pass
        # Per transformed dim the conditioner emits:
        #   K widths + K heights + (K+1) knot derivatives.
        self.net = nn.Sequential(
            nn.Linear(n_pass + d_context, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, self.n_b * (3 * n_bins + 1)),
        )
        # Zero-init the final layer so the spline starts near-identity.
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def _get_spline_params(self, passthrough, context):
        """Run the conditioner and split its output into spline parameters."""
        K = self.n_bins
        raw = self.net(torch.cat([passthrough, context], dim=-1))
        raw = raw.view(-1, self.n_b, 3 * K + 1)
        # Bound raw outputs to prevent extreme spline configurations.
        # softmax(widths/heights) is shift-invariant so bounding doesn't
        # reduce expressiveness, it just prevents near-one-hot bin
        # allocations. softplus(derivatives) with bounded input caps the
        # knot slopes.
        widths = raw[..., :K].clamp(-5.0, 5.0)
        heights = raw[..., K:2 * K].clamp(-5.0, 5.0)
        derivatives = raw[..., 2 * K:].clamp(-5.0, 5.0)
        return widths, heights, derivatives

    def forward(self, z, context):
        """z -> theta; returns (theta, log|det J|) summed over transformed dims."""
        part_a, part_b = z[:, self.mask], z[:, ~self.mask]
        widths, heights, derivatives = self._get_spline_params(part_a, context)
        new_b, ld_dims = rational_quadratic_spline_forward(
            part_b, widths, heights, derivatives, self.tail_bound)
        out = torch.empty_like(z)
        out[:, self.mask] = part_a
        out[:, ~self.mask] = new_b
        return out, ld_dims.sum(dim=-1)

    def inverse(self, theta, context):
        """theta -> z; returns (z, log|det J^{-1}|) summed over transformed dims."""
        part_a, part_b = theta[:, self.mask], theta[:, ~self.mask]
        widths, heights, derivatives = self._get_spline_params(part_a, context)
        new_b, ld_dims = rational_quadratic_spline_inverse(
            part_b, widths, heights, derivatives, self.tail_bound)
        out = torch.empty_like(theta)
        out[:, self.mask] = part_a
        out[:, ~self.mask] = new_b
        return out, ld_dims.sum(dim=-1)
# =============================================================================
# Full Conditional Normalizing Flow
# =============================================================================
class ConditionalFlow(nn.Module):
    """
    Conditional normalizing flow for p(theta | x).
    Maps between base distribution z ~ N(0, I) and parameter space theta,
    conditioned on the observed signal x.
    Supports both affine and spline coupling layers.

    Layer stack: n_coupling_layers repetitions of [ActNorm, coupling],
    with the coupling mask alternating 'even'/'odd' so every dimension is
    transformed. theta is standardized by (theta_mean, theta_std) before
    entering the flow; `log_prob` includes the Jacobian of that
    standardization.
    """
    def __init__(
        self,
        theta_dim=3,
        d_context=128,
        n_coupling_layers=8,
        hidden_dim=128,
        d_model=128,
        coupling_type='spline',
        n_bins=8,
        tail_bound=5.0,
    ):
        super().__init__()
        self.theta_dim = theta_dim
        self.d_context = d_context
        self.coupling_type = coupling_type
        self.tail_bound = tail_bound
        self.encoder = SignalEncoder(
            in_channels=3, d_model=d_model, d_context=d_context
        )
        self.flows = nn.ModuleList()
        for i in range(n_coupling_layers):
            # Alternate masks so both halves of theta get transformed.
            mask_type = 'even' if i % 2 == 0 else 'odd'
            self.flows.append(ActNorm(theta_dim))
            if coupling_type == 'spline':
                self.flows.append(
                    ConditionalSplineCoupling(
                        dim=theta_dim,
                        d_context=d_context,
                        hidden_dim=hidden_dim,
                        mask_type=mask_type,
                        n_bins=n_bins,
                        tail_bound=tail_bound,
                    )
                )
            else:
                # Any other coupling_type falls back to affine coupling.
                self.flows.append(
                    ConditionalAffineCoupling(
                        dim=theta_dim,
                        d_context=d_context,
                        hidden_dim=hidden_dim,
                        mask_type=mask_type,
                    )
                )
        # Standardization stats for theta; buffers so they save/load and
        # move with the module's device. Defaults are identity (0 mean, 1 std).
        self.register_buffer('theta_mean', torch.zeros(theta_dim))
        self.register_buffer('theta_std', torch.ones(theta_dim))
    def set_theta_stats(self, mean, std):
        """Register standardization statistics for theta (e.g. training-set stats)."""
        self.theta_mean.copy_(torch.as_tensor(mean, dtype=torch.float32))
        self.theta_std.copy_(torch.as_tensor(std, dtype=torch.float32))
    def normalize_theta(self, theta):
        """Standardize theta with the registered mean/std."""
        return (theta - self.theta_mean) / self.theta_std
    def denormalize_theta(self, theta_norm):
        """Undo `normalize_theta`."""
        return theta_norm * self.theta_std + self.theta_mean
    def encode_signal(self, x, mask=None):
        """Encode observed signal x [B, 3, T] into a context vector [B, d_context]."""
        return self.encoder(x, mask=mask)
    def forward_flow(self, z, context):
        """Base -> parameter space: z -> theta_norm, accumulating log|det J|."""
        total_log_det = torch.zeros(z.shape[0], device=z.device)
        h = z
        for layer in self.flows:
            if isinstance(layer, ActNorm):
                # ActNorm is unconditional (takes no context).
                h, ld = layer(h)
                total_log_det += ld
            else:
                h, ld = layer(h, context)
                total_log_det += ld
        return h, total_log_det
    def inverse_flow(self, theta_norm, context):
        """Parameter -> base space: theta_norm -> z, accumulating log|det J^{-1}|.

        Layers are traversed in reverse order relative to `forward_flow`.
        """
        total_log_det = torch.zeros(theta_norm.shape[0], device=theta_norm.device)
        h = theta_norm
        for layer in reversed(self.flows):
            if isinstance(layer, ActNorm):
                # ActNorm.inverse returns only x; its inverse log|det| is the
                # constant -sum(log_scale), added here by hand.
                h = layer.inverse(h)
                total_log_det -= layer.log_scale.sum()
            else:
                h, ld = layer.inverse(h, context)
                total_log_det += ld
        return h, total_log_det
    def log_prob(self, theta, x, mask=None):
        """Evaluate log q_phi(theta | x) via change of variables.

        Returns a [B] tensor. Note: theta_norm is clamped to the spline
        tail bound, so parameters further than tail_bound standardized
        units from the mean are evaluated at the boundary.
        """
        context = self.encode_signal(x, mask=mask)
        theta_norm = self.normalize_theta(theta)
        theta_norm = theta_norm.clamp(-self.tail_bound, self.tail_bound)
        z, log_det = self.inverse_flow(theta_norm, context)
        # Standard-normal base density.
        log_pz = -0.5 * (z ** 2 + math.log(2 * math.pi)).sum(dim=-1)
        # Jacobian of the theta -> theta_norm standardization.
        log_det_norm = -torch.log(self.theta_std).sum()
        log_p = log_pz + log_det + log_det_norm
        # Clamp guards the training loss against extreme outliers.
        return log_p.clamp(min=-50.0, max=50.0)
    @torch.no_grad()
    def sample(self, x, mask=None, n_samples=100):
        """Draw n_samples posterior samples per observation; returns [B, n_samples, theta_dim]."""
        B = x.shape[0]
        context = self.encode_signal(x, mask=mask)
        # Repeat each context n_samples times so one flow pass covers all draws.
        context_rep = context.unsqueeze(1).expand(-1, n_samples, -1)
        context_rep = context_rep.reshape(B * n_samples, -1)
        z = torch.randn(B * n_samples, self.theta_dim, device=x.device)
        theta_norm, _ = self.forward_flow(z, context_rep)
        theta = self.denormalize_theta(theta_norm)
        return theta.reshape(B, n_samples, self.theta_dim)
    @torch.no_grad()
    def posterior_stats(self, x, mask=None, n_samples=1000):
        """Monte-Carlo summary statistics of the posterior (per observation)."""
        samples = self.sample(x, mask=mask, n_samples=n_samples)
        return {
            'mean': samples.mean(dim=1),
            'std': samples.std(dim=1),
            'median': samples.median(dim=1).values,
            'q05': samples.quantile(0.05, dim=1),
            'q95': samples.quantile(0.95, dim=1),
            'samples': samples,
        }
# Backward-compatible alias
ConditionalRealNVP = ConditionalFlow
# =============================================================================
# Helper
# =============================================================================
# Legacy default parameter names for the original 3-parameter setup
# (identical to MECHANISM_PARAMS['BV']['names']).
THETA_NAMES = ['log10(K0)', 'alpha', 'log10(dB)']
# Per-mechanism parameter definitions (variable dim per mechanism)
# 'names' gives the theta dimensions in order; 'dim' == len(names).
# NOTE(review): abbreviations presumably BV = Butler-Volmer,
# MHC = Marcus-Hush-Chidsey, Ads = adsorption, LH = Langmuir-Hinshelwood
# — confirm against the simulator's documentation.
MECHANISM_PARAMS = {
    'Nernst': {
        'names': ['E0_offset', 'log10(dA)', 'log10(dB)'],
        'dim': 3,
    },
    'BV': {
        'names': ['log10(K0)', 'alpha', 'log10(dB)'],
        'dim': 3,
    },
    'MHC': {
        'names': ['log10(K0)', 'log10(reorg_e)', 'log10(dB)'],
        'dim': 3,
    },
    'Ads': {
        'names': ['log10(K0)', 'alpha', 'log10(Gamma_sat)'],
        'dim': 3,
    },
    'EC': {
        'names': ['log10(K0)', 'alpha', 'log10(kc)', 'log10(dB)'],
        'dim': 4,
    },
    'LH': {
        'names': ['log10(K0)', 'alpha', 'log10(KA_eq)', 'log10(KB_eq)', 'log10(dB)'],
        'dim': 5,
    },
}
# Fixed ordering of the mechanism keys above.
MECHANISM_LIST = ['Nernst', 'BV', 'MHC', 'Ads', 'EC', 'LH']
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in `model`."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
if __name__ == "__main__":
B, T = 4, 800
theta_dim = 3
x = torch.randn(B, 3, T)
theta = torch.randn(B, theta_dim)
mask = torch.ones(B, T, dtype=torch.bool)
for coupling_type in ['affine', 'spline']:
print(f"\n{'=' * 50}")
print(f"Testing ConditionalFlow (coupling={coupling_type})")
print(f"{'=' * 50}")
model = ConditionalFlow(
theta_dim=theta_dim,
d_context=128,
n_coupling_layers=8,
hidden_dim=128,
d_model=128,
coupling_type=coupling_type,
)
print(f"Parameters: {count_parameters(model):,}")
log_q = model.log_prob(theta, x, mask=mask)
print(f"log_prob shape: {log_q.shape}, values: {log_q}")
samples = model.sample(x, mask=mask, n_samples=100)
print(f"Samples shape: {samples.shape}")
print(f"Sample mean: {samples.mean(dim=1)}")
print(f"Sample std: {samples.std(dim=1)}")