krystv committed on
Commit
714db4b
·
verified ·
1 Parent(s): 34bede7

Upload liquid_flow/cfc_cell.py

Browse files
Files changed (1) hide show
  1. liquid_flow/cfc_cell.py +169 -0
liquid_flow/cfc_cell.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CfC Cell — Closed-form Continuous-time neural network cell.
3
+
4
+ From: "Closed-form Continuous-time Neural Networks" (Hasani et al., 2022)
5
+
6
+ The CfC model provides an approximate closed-form solution to Liquid Time-Constant (LTC)
7
+ network dynamics without needing ODE solvers.
8
+
9
+ Architecture:
10
+ x(t) = σ(-f(x,I;θ_f) · t) ⊙ g(x,I;θ_g) + (1 - σ(-f(x,I;θ_f) · t)) ⊙ h(x,I;θ_h)
11
+
12
+ Where:
13
+ - f, g, h are neural network heads sharing a backbone
14
+ - σ is the sigmoid (replacing exponential decay for gradient stability)
15
+ - t is a time parameter
16
+ - The sigmoidal terms act as time-continuous gates between g and h
17
+
18
+ Key properties:
19
+ - No ODE solving → 100x+ faster than Neural ODEs
20
+ - Time-continuous gating mechanism → adaptive computation
21
+ - Closed-form → stable gradients, easy to train
22
+ - Naturally causal → good for sequential processing
23
+
24
+ For 2D image inputs: we treat the spatial sequence as "time" steps for the CfC,
25
+ allowing the liquid dynamics to model spatial dependencies with adaptive gates.
26
+ """
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+
32
+
33
class CfCCell(nn.Module):
    """
    Single CfC cell with a shared backbone + 3 heads (f, g, h).

    One step computes

        gate = sigmoid(-f * t)
        out  = out_proj(gate * g + (1 - gate) * h + skip)

    where f, g, h and skip are the four equal chunks of the backbone output
    (f, g, h are additionally passed through their own heads).

    Args:
        dim: Hidden dimension (also the expected feature size of the input).
        backbone_dropout: Dropout probability inside the backbone.
        time_scale: Range [a, b] used to sample `t` uniformly when the caller
            does not pass one.
        use_conv: Add the depthwise conv1d refinement to the f/g/h chunks.

    NOTE: when `t` is None a fresh random `t` is drawn on every call,
    regardless of train/eval mode, so outputs are stochastic unless `t` is
    passed explicitly.
    """

    def __init__(self, dim, backbone_dropout=0.0, time_scale=(0.0, 1.0), use_conv=True):
        super().__init__()
        self.dim = dim
        self.time_scale = time_scale

        # Shared backbone: [x, h_prev] -> 4 * dim features (f, g, h, skip).
        backbone_dim = dim * 3
        self.backbone = nn.Sequential(
            nn.Linear(dim + dim, backbone_dim),
            nn.LayerNorm(backbone_dim),
            nn.SiLU(),
            nn.Dropout(backbone_dropout),
            nn.Linear(backbone_dim, dim * 4),
            nn.LayerNorm(dim * 4),
        )

        # Optional depthwise 1D conv (one filter per channel).
        self.conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1, groups=dim) if use_conv else None

        # Heads
        self.f_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.Tanh())
        self.g_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.GELU())
        self.h_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.GELU())

        self.out_proj = nn.Linear(dim, dim)
        self._init_weights()

    def _init_weights(self):
        """Initialise every Linear with N(0, 0.02) weights and zero bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x, h_prev=None, t=None):
        """
        Args:
            x: [B, dim] (single step) or [B, L, dim] (sequence).
            h_prev: Previous hidden state [B, dim]; zeros when None.
            t: Time parameter. May be None (sampled from `time_scale`), a
                Python number, or a tensor of shape [], [B], [B, 1],
                [B, dim] or [B, 1, 1].
        Returns: h: [B, dim] or [B, L, dim]
        """
        if x.dim() == 3:
            return self._forward_seq(x, h_prev, t)

        if h_prev is None:
            # new_zeros keeps the state on x's device *and* dtype.
            h_prev = x.new_zeros(x.shape[0], self.dim)
        return self._step(x, h_prev, self._prepare_time(t, x))

    def _prepare_time(self, t, x):
        """Return `t` as a tensor on x's device/dtype that broadcasts against [B, dim]."""
        if t is None:
            # Sample in the default dtype, then cast, so this also works for
            # reduced-precision inputs.
            sampled = torch.rand(x.shape[0], 1, device=x.device)
            sampled = sampled * (self.time_scale[1] - self.time_scale[0]) + self.time_scale[0]
            return sampled.to(x.dtype)

        t = torch.as_tensor(t, device=x.device, dtype=x.dtype)
        if t.dim() == 0:
            return t.reshape(1, 1)
        if t.dim() == 1:
            # One time value per batch element -> [B, 1].
            return t.unsqueeze(1)
        if t.dim() == 3:
            # [B, 1, 1] -> [B, 1]
            return t.squeeze(-1)
        return t

    def _forward_seq(self, x, h_prev=None, t=None):
        """Run the cell over x[:, 0], x[:, 1], ... carrying the hidden state; the same `t` is used at every step."""
        B, L, D = x.shape

        # Normalised once; it does not change between steps.
        t = self._prepare_time(t, x)

        if L == 0:
            # torch.stack cannot handle an empty list; return an empty sequence.
            return x.new_zeros(B, 0, self.dim)

        h = x.new_zeros(B, D) if h_prev is None else h_prev
        outputs = []
        for step in range(L):
            h = self._step(x[:, step, :], h, t)
            outputs.append(h)
        return torch.stack(outputs, dim=1)

    def _local_mix(self, z):
        """Apply the depthwise conv to a [B, dim] tensor viewed as [B, dim, 1].

        The conv only ever sees a length-1 signal here, so with zero padding
        it reduces to a per-channel scale by the centre tap plus the bias.
        """
        return self.conv(z.unsqueeze(-1)).squeeze(-1)

    def _step(self, x, h_prev, t):
        """Core CfC step: x, h_prev are [B, dim]; t broadcasts against [B, dim]."""
        combined = torch.cat([x, h_prev], dim=-1)
        f_base, g_base, h_base, skip = self.backbone(combined).chunk(4, dim=-1)

        if self.conv is not None:
            f_base = f_base + self._local_mix(f_base)
            g_base = g_base + self._local_mix(g_base)
            h_base = h_base + self._local_mix(h_base)

        f_out = self.f_head(f_base)
        g_out = self.g_head(g_base)
        h_out = self.h_head(h_base)

        # Time-dependent gate blending the g and h branches.
        gate = torch.sigmoid(-f_out * t)
        h = gate * g_out + (1 - gate) * h_out + skip
        return self.out_proj(h)
133
+
134
+
135
class CfCBlock(nn.Module):
    """
    CfC block for 2D image processing.

    Adds a learned positional embedding, runs a CfCCell over the flattened
    token sequence, then applies a feed-forward layer.

    Args:
        dim: Channel / feature dimension.
        dropout: Dropout probability for the CfC backbone and feed-forward.
        time_scale: Range [a, b] forwarded to the CfCCell.
        expansion_factor: Feed-forward hidden width as a multiple of `dim`.
        max_positions: Number of rows in the learned positional table, i.e.
            the longest sequence (H * W for images) the block accepts.
    """

    def __init__(self, dim, dropout=0.0, time_scale=(0.0, 1.0), expansion_factor=2, max_positions=4096):
        super().__init__()
        self.dim = dim
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.cfc = CfCCell(dim=dim, backbone_dropout=dropout, time_scale=time_scale, use_conv=True)

        ff_dim = dim * expansion_factor
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, dim), nn.Dropout(dropout),
        )

        self.pos_embed = nn.Parameter(torch.randn(1, max_positions, dim) * 0.02)

    def forward(self, x, return_2d=True):
        """
        Args:
            x: [B, C, H, W] image features or [B, L, C] token sequence.
            return_2d: For 4-D input, reshape the result back to [B, C, H, W].
        Returns: [B, C, H, W] if the input was 4-D and `return_2d` is set,
            otherwise [B, L, C].
        Raises:
            ValueError: If the sequence length exceeds the positional table.
        """
        is_2d = x.dim() == 4
        if is_2d:
            B, C, H, W = x.shape
            L = H * W
            # [B, C, H, W] -> [B, H*W, C]
            x = x.flatten(2).transpose(1, 2)
        else:
            B, L, C = x.shape

        max_len = self.pos_embed.shape[1]
        if L > max_len:
            # Slicing would silently clamp to max_len and the addition below
            # would then fail with an opaque broadcasting error.
            raise ValueError(
                f"sequence length {L} exceeds the positional embedding table ({max_len} positions)"
            )

        x_with_pos = x + self.pos_embed[:, :L, :]
        residual = x
        h = self.cfc(self.norm1(x_with_pos))
        # NOTE(review): `residual` only feeds the feed-forward input; it is
        # not added to the block output. Confirm this is the intended wiring.
        x_out = h + self.ff(self.norm2(h + residual))

        if is_2d and return_2d:
            x_out = x_out.transpose(1, 2).reshape(B, C, H, W)
        return x_out