"""
CfC Cell — Closed-form Continuous-time neural network cell.
From: "Closed-form Continuous-time Neural Networks" (Hasani et al., 2022)
The CfC model provides an approximate closed-form solution to Liquid Time-Constant (LTC)
network dynamics without needing ODE solvers.
Architecture:
x(t) = σ(-f(x,I;θ_f) · t) ⊙ g(x,I;θ_g) + (1 - σ(-f(x,I;θ_f) · t)) ⊙ h(x,I;θ_h)
Where:
- f, g, h are neural network heads sharing a backbone
- σ is the sigmoid (replacing exponential decay for gradient stability)
- t is a time parameter
- The sigmoidal terms act as time-continuous gates between g and h
Key properties:
- No ODE solving → 100x+ faster than Neural ODEs
- Time-continuous gating mechanism → adaptive computation
- Closed-form → stable gradients, easy to train
- Naturally causal → good for sequential processing
For 2D image inputs: we treat the spatial sequence as "time" steps for the CfC,
allowing the liquid dynamics to model spatial dependencies with adaptive gates.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class CfCCell(nn.Module):
    """Closed-form Continuous-time (CfC) cell.

    A shared backbone maps [x; h_prev] to features for three heads (f, g, h)
    plus a skip path; the sigmoidal time gate sigma(-f*t) blends g against h,
    approximating LTC dynamics without an ODE solver.

    Args:
        dim: Hidden dimension.
        backbone_dropout: Dropout probability inside the backbone MLP.
        time_scale: (low, high) range the time parameter t is sampled from
            when the caller does not supply one.
        use_conv: If True, mix each head's features with a depthwise Conv1d.
    """

    def __init__(self, dim, backbone_dropout=0.0, time_scale=(0.0, 1.0), use_conv=True):
        super().__init__()
        self.dim = dim
        self.time_scale = time_scale
        # Backbone: 2*dim input ([x; h_prev]) -> 4*dim output (f/g/h/skip slices).
        hidden = dim * 3
        self.backbone = nn.Sequential(
            nn.Linear(2 * dim, hidden),
            nn.LayerNorm(hidden),
            nn.SiLU(),
            nn.Dropout(backbone_dropout),
            nn.Linear(hidden, 4 * dim),
            nn.LayerNorm(4 * dim),
        )
        # Optional depthwise conv for per-channel local mixing.
        self.conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1, groups=dim) if use_conv else None
        self.f_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.Tanh())
        self.g_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.GELU())
        self.h_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.GELU())
        self.out_proj = nn.Linear(dim, dim)
        self._init_weights()

    def _init_weights(self):
        """Small-normal init (std=0.02) on every linear weight; zero the biases."""
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x, h_prev=None, t=None):
        """Run one step ([B, dim] input) or a full sequence ([B, L, dim]).

        Args:
            x: Input of shape [B, dim] or [B, L, dim].
            h_prev: Optional previous hidden state, [B, dim]; zeros if None.
            t: Optional time parameter; sampled from ``time_scale`` if None.

        Returns:
            Hidden state(s) with the same leading shape as ``x``.
        """
        if x.dim() == 3:  # sequence input -> unroll over the length axis
            return self._forward_seq(x, h_prev, t)
        batch = x.shape[0]
        if h_prev is None:
            h_prev = torch.zeros(batch, self.dim, device=x.device)
        if t is None:
            lo, hi = self.time_scale
            t = lo + (hi - lo) * torch.rand(batch, 1, device=x.device)
        elif t.dim() == 1:
            t = t[:, None]  # -> [B, 1] so it broadcasts over features
        return self._step(x, h_prev, t)

    def _forward_seq(self, x, h_prev=None, t=None):
        """Unroll the cell over axis 1 of a [B, L, dim] input; one shared t."""
        batch, length, width = x.shape
        if t is None:
            lo, hi = self.time_scale
            t = lo + (hi - lo) * torch.rand(batch, 1, 1, device=x.device)
        # _step expects t broadcastable against [B, dim]; drop a trailing axis.
        t_flat = t.squeeze(-1) if t.dim() == 3 else t
        state = torch.zeros(batch, width, device=x.device) if h_prev is None else h_prev
        states = []
        for idx in range(length):
            state = self._step(x[:, idx, :], state, t_flat)
            states.append(state)
        return torch.stack(states, dim=1)

    def _step(self, x, h_prev, t):
        """Core CfC step: gate = sigma(-f*t); h = gate*g + (1-gate)*h + skip."""
        features = self.backbone(torch.cat([x, h_prev], dim=-1))
        f_base, g_base, h_base, skip = features.chunk(4, dim=-1)
        if self.conv is not None:
            # Residual depthwise conv: each [B, dim] slice is fed to Conv1d as
            # a length-1 sequence ([B, dim, 1]).
            def mix(z):
                return z + self.conv(z.unsqueeze(-1)).squeeze(-1)
            f_base = mix(f_base)
            g_base = mix(g_base)
            h_base = mix(h_base)
        gate = torch.sigmoid(-self.f_head(f_base) * t)
        blended = gate * self.g_head(g_base) + (1 - gate) * self.h_head(h_base) + skip
        return self.out_proj(blended)
class CfCBlock(nn.Module):
    """CfC block for 2D image processing with residual connection.

    Flattens a [B, C, H, W] input into a [B, H*W, C] sequence (or accepts a
    [B, L, C] sequence directly), adds learned position embeddings, runs the
    CfC cell, and applies a feed-forward sub-block.

    Args:
        dim: Channel/feature dimension.
        dropout: Dropout used in both the CfC backbone and the FF layers.
        time_scale: Passed through to CfCCell's time-parameter sampling.
        expansion_factor: FF hidden width multiplier.
        max_seq_len: Capacity of the learned position embedding table.
            Inputs longer than this raise ValueError (previously the 4096
            hard cap failed with a cryptic broadcast error).
    """

    def __init__(self, dim, dropout=0.0, time_scale=(0.0, 1.0), expansion_factor=2,
                 max_seq_len=4096):
        super().__init__()
        self.dim = dim
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.cfc = CfCCell(dim=dim, backbone_dropout=dropout, time_scale=time_scale, use_conv=True)
        ff_dim = dim * expansion_factor
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, dim), nn.Dropout(dropout),
        )
        # Learned absolute position embeddings, sliced to the actual length.
        self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len, dim) * 0.02)

    def forward(self, x, return_2d=True):
        """Apply the block.

        Args:
            x: [B, C, H, W] image tensor or [B, L, C] sequence tensor.
            return_2d: If the input was 4D, reshape the output back to
                [B, C, H, W]; otherwise the output stays [B, L, C].

        Returns:
            Tensor shaped like the input (subject to ``return_2d``).

        Raises:
            ValueError: If the sequence length exceeds the position-embedding
                capacity chosen at construction.
        """
        is_2d = x.dim() == 4
        if is_2d:
            B, C, H, W = x.shape
            L = H * W
            x = x.flatten(2).transpose(1, 2)  # [B, C, H, W] -> [B, L, C]
        else:
            B, L, C = x.shape
        capacity = self.pos_embed.shape[1]
        if L > capacity:
            raise ValueError(
                f"sequence length {L} exceeds position-embedding capacity {capacity}; "
                f"increase max_seq_len"
            )
        x_with_pos = x + self.pos_embed[:, :L, :]
        residual = x
        h = self.cfc(self.norm1(x_with_pos))
        # NOTE(review): the CfC output replaces x rather than being added to
        # it; the input residual only enters through the FF branch. Confirm
        # this asymmetric residual wiring is intentional.
        x_out = h + self.ff(self.norm2(h + residual))
        if is_2d and return_2d:
            x_out = x_out.transpose(1, 2).reshape(B, C, H, W)
        return x_out