"""
CfC Cell — Closed-form Continuous-time neural network cell.

From: "Closed-form Continuous-time Neural Networks" (Hasani et al., 2022)

The CfC model provides an approximate closed-form solution to Liquid
Time-Constant (LTC) network dynamics without needing an ODE solver.

Architecture:
    x(t) = σ(-f(x,I;θ_f) · t) ⊙ g(x,I;θ_g) + (1 - σ(-f(x,I;θ_f) · t)) ⊙ h(x,I;θ_h)

Where:
    - f, g, h are neural network heads sharing a backbone
    - σ is the sigmoid (replacing exponential decay for gradient stability)
    - t is a time parameter
    - the sigmoidal terms act as time-continuous gates between g and h

This implementation additionally adds a ``skip`` chunk produced by the shared
backbone to the gated mixture and passes the result through an output
projection (see ``CfCCell._step``).

Key properties:
    - No ODE solving → 100x+ faster than Neural ODEs
    - Time-continuous gating mechanism → adaptive computation
    - Closed-form → stable gradients, easy to train
    - Naturally causal → good for sequential processing

For 2D image inputs, the flattened spatial positions (row-major over H, W)
are treated as "time" steps for the CfC, allowing the liquid dynamics to
model spatial dependencies with adaptive gates.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class CfCCell(nn.Module):
    """
    Single CfC cell with a shared backbone + 3 heads (f, g, h).

    Args:
        dim: Hidden dimension. Inputs and the hidden state both have size ``dim``.
        backbone_dropout: Dropout probability inside the backbone.
        time_scale: Range ``(a, b)`` for the time parameter. When no ``t`` is
            passed, ``t`` is drawn uniformly from ``[a, b)`` per sample in
            training mode. In eval mode the midpoint ``(a + b) / 2`` is used,
            so inference is deterministic.
        use_conv: Add a depthwise Conv1d, shared by the f/g/h pre-activations.
            NOTE(review): the cell runs one step at a time, so the conv always
            sees a length-1 sequence. With ``padding=1`` only the kernel's
            center tap touches real data. It therefore acts as a learned
            per-channel scale + bias, not as local spatial context.
    """

    def __init__(self, dim, backbone_dropout=0.0, time_scale=(0.0, 1.0), use_conv=True):
        super().__init__()
        self.dim = dim
        self.time_scale = time_scale

        # Shared backbone: [x, h_prev] -> 4 * dim features, later split into
        # the f, g, h pre-activations and a skip term.
        backbone_dim = dim * 3
        self.backbone = nn.Sequential(
            nn.Linear(dim + dim, backbone_dim),
            nn.LayerNorm(backbone_dim),
            nn.SiLU(),
            nn.Dropout(backbone_dropout),
            nn.Linear(backbone_dim, dim * 4),
            nn.LayerNorm(dim * 4),
        )

        # Optional depthwise 1D conv (one filter per channel).
        self.conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1, groups=dim) if use_conv else None

        # Heads
        self.f_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.Tanh())
        self.g_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.GELU())
        self.h_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.GELU())

        self.out_proj = nn.Linear(dim, dim)

        self._init_weights()

    def _init_weights(self):
        """Init every Linear weight ~ N(0, 0.02) and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x, h_prev=None, t=None):
        """
        Args:
            x: [B, dim] (single step) or [B, L, dim] (sequence).
            h_prev: Previous hidden state [B, dim]; zeros when None.
            t: Time parameter. Accepted forms:
                - None: sampled or midpoint (see class docstring)
                - a Python number or 0-d tensor
                - a tensor of shape [B], [B, 1] or [B, 1, 1]
                For a sequence, the same ``t`` is used at every step.

        Returns:
            h: [B, dim] for a single step, or [B, L, dim] holding the hidden
            state after every step of the sequence.
        """
        if x.dim() == 3:
            return self._forward_seq(x, h_prev, t)
        if h_prev is None:
            # Match x's dtype/device so mixed-precision inputs work.
            h_prev = x.new_zeros(x.shape[0], self.dim)
        return self._step(x, h_prev, self._resolve_time(t, x))

    def _resolve_time(self, t, ref):
        """Return ``t`` as a tensor that broadcasts against a [B, dim] step.

        ``ref`` (the input tensor) supplies batch size, device and dtype.
        """
        batch = ref.shape[0]
        if t is None:
            lo, hi = self.time_scale
            if self.training:
                # t ~ U[lo, hi), one value per sample.
                sampled = torch.rand(batch, 1, device=ref.device) * (hi - lo) + lo
                return sampled.to(dtype=ref.dtype)
            # Deterministic inference: use the center of the sampling range.
            return torch.full((batch, 1), (lo + hi) / 2, device=ref.device, dtype=ref.dtype)
        t = torch.as_tensor(t, device=ref.device, dtype=ref.dtype)
        if t.dim() == 1:
            t = t.unsqueeze(1)  # [B] -> [B, 1]: one time value per sample
        elif t.dim() == 3:
            t = t.squeeze(-1)  # [B, 1, 1] -> [B, 1]
        return t

    def _forward_seq(self, x, h_prev=None, t=None):
        """Run the cell over x [B, L, dim] one step at a time; return [B, L, dim]."""
        batch, length, _ = x.shape
        if length == 0:
            # torch.stack would fail on an empty list.
            return x.new_zeros(batch, 0, self.dim)
        t = self._resolve_time(t, x)
        h = x.new_zeros(batch, self.dim) if h_prev is None else h_prev
        outputs = []
        for step in range(length):
            h = self._step(x[:, step, :], h, t)
            outputs.append(h)
        return torch.stack(outputs, dim=1)

    def _step(self, x, h_prev, t):
        """Core CfC update for a single step.

        Args:
            x: [B, dim] input at this step.
            h_prev: [B, dim] previous hidden state.
            t: Time tensor broadcastable to [B, dim].

        Returns:
            [B, dim] new hidden state (after ``out_proj``).
        """
        backbone_out = self.backbone(torch.cat([x, h_prev], dim=-1))
        f_base, g_base, h_base, skip = backbone_out.chunk(4, dim=-1)

        if self.conv is not None:
            f_base = self._conv_residual(f_base)
            g_base = self._conv_residual(g_base)
            h_base = self._conv_residual(h_base)

        f_out = self.f_head(f_base)
        g_out = self.g_head(g_base)
        h_out = self.h_head(h_base)

        # Time-continuous gate interpolating between the g and h heads.
        gate = torch.sigmoid(-f_out * t)
        h = gate * g_out + (1 - gate) * h_out + skip
        return self.out_proj(h)

    def _conv_residual(self, z):
        """Return z + conv(z) for z [B, dim], viewed as dim channels of length 1."""
        return z + self.conv(z.unsqueeze(-1)).squeeze(-1)


class CfCBlock(nn.Module):
    """CfC block for 2D image (or token sequence) processing with residual connections.

    Pre-norm layout:
        x = x + CfC(norm1(x + pos_embed))
        x = x + FF(norm2(x))

    Args:
        dim: Channel / hidden dimension (C of the input must equal ``dim``).
        dropout: Dropout used in the CfC backbone and in the feed-forward MLP.
        time_scale: Forwarded to ``CfCCell``.
        expansion_factor: Hidden-width multiplier of the feed-forward MLP.
        max_len: Longest sequence (H * W for images) supported by the learned
            positional embedding.
    """

    def __init__(self, dim, dropout=0.0, time_scale=(0.0, 1.0), expansion_factor=2, max_len=4096):
        super().__init__()
        self.dim = dim
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.cfc = CfCCell(dim=dim, backbone_dropout=dropout, time_scale=time_scale, use_conv=True)

        ff_dim = dim * expansion_factor
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, dim),
            nn.Dropout(dropout),
        )

        self.pos_embed = nn.Parameter(torch.randn(1, max_len, dim) * 0.02)

    def forward(self, x, return_2d=True):
        """
        Args:
            x: [B, C, H, W] image or [B, L, C] sequence, with C == dim.
            return_2d: For 2D input, reshape the output back to [B, C, H, W].
                If False, return the flattened [B, H*W, C] sequence instead.

        Returns:
            Tensor in the same layout as ``x`` (or [B, H*W, C], see above).

        Raises:
            ValueError: if the sequence is longer than the positional embedding.
        """
        is_2d = x.dim() == 4
        if is_2d:
            batch, channels, height, width = x.shape
            x = x.flatten(2).transpose(1, 2)  # [B, H*W, C], row-major over (H, W)

        seq_len = x.shape[1]
        max_len = self.pos_embed.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional embedding length {max_len}"
            )

        residual = x
        h = self.cfc(self.norm1(x + self.pos_embed[:, :seq_len, :]))
        x_mid = residual + h  # residual around the CfC sub-layer
        x_out = x_mid + self.ff(self.norm2(x_mid))  # residual around the MLP

        if is_2d and return_2d:
            x_out = x_out.transpose(1, 2).reshape(batch, channels, height, width)
        return x_out