File size: 6,075 Bytes
714db4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
CfC Cell — Closed-form Continuous-time neural network cell.

From: "Closed-form Continuous-time Neural Networks" (Hasani et al., 2022)

The CfC model provides an approximate closed-form solution to Liquid Time-Constant (LTC)
network dynamics without needing ODE solvers.

Architecture:
    x(t) = σ(-f(x,I;θ_f) · t) ⊙ g(x,I;θ_g) + (1 - σ(-f(x,I;θ_f) · t)) ⊙ h(x,I;θ_h)

Where:
    - f, g, h are neural network heads sharing a backbone
    - σ is the sigmoid (replacing exponential decay for gradient stability)  
    - t is a time parameter
    - The sigmoidal terms act as time-continuous gates between g and h

Key properties:
    - No ODE solving → 100x+ faster than Neural ODEs
    - Time-continuous gating mechanism → adaptive computation
    - Closed-form → stable gradients, easy to train
    - Naturally causal → good for sequential processing

For 2D image inputs: we treat the spatial sequence as "time" steps for the CfC,
allowing the liquid dynamics to model spatial dependencies with adaptive gates.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class CfCCell(nn.Module):
    """
    Single CfC cell with a shared backbone + 3 heads (f, g, h).

    Implements one step of the closed-form continuous-time update
    (Hasani et al., 2022):

        h_new = sigma(-f * t) * g + (1 - sigma(-f * t)) * h + skip

    where f, g, h are produced from a shared backbone over [x, h_prev]
    and sigma is the sigmoid acting as a time-continuous gate.

    Args:
        dim: Hidden dimension.
        backbone_dropout: Dropout probability inside the backbone.
        time_scale: (low, high) range from which the time parameter t is
            sampled uniformly when the caller does not supply one.
        use_conv: If True, each head input gets a residual depthwise
            Conv1d refinement (applied per step over a length-1 axis, so
            only the center tap of the kernel contributes).
    """

    def __init__(self, dim, backbone_dropout=0.0, time_scale=(0.0, 1.0), use_conv=True):
        super().__init__()
        self.dim = dim
        self.time_scale = time_scale

        # Shared backbone: [x, h_prev] (2*dim) -> 4 chunks of size dim
        # (bases for the f/g/h heads plus a skip term).
        backbone_dim = dim * 3
        self.backbone = nn.Sequential(
            nn.Linear(dim + dim, backbone_dim),
            nn.LayerNorm(backbone_dim),
            nn.SiLU(),
            nn.Dropout(backbone_dropout),
            nn.Linear(backbone_dim, dim * 4),
            nn.LayerNorm(dim * 4),
        )

        # Optional depthwise 1D conv (groups=dim -> one filter per channel).
        self.conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1, groups=dim) if use_conv else None

        # Heads: f drives the gate (tanh keeps its argument bounded),
        # g and h are the two states the gate interpolates between.
        self.f_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.Tanh())
        self.g_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.GELU())
        self.h_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.GELU())

        self.out_proj = nn.Linear(dim, dim)
        self._init_weights()

    def _init_weights(self):
        # Small-std normal init (0.02), zero biases — transformer-style.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _sample_t(self, shape, device):
        """Sample t uniformly from [time_scale[0], time_scale[1])."""
        lo, hi = self.time_scale
        return torch.rand(*shape, device=device) * (hi - lo) + lo

    def _normalize_t(self, t):
        """Bring a caller-supplied t to a [B, 1]-broadcastable shape."""
        if t.dim() == 1:
            # FIX: a 1-D t ([B]) previously reached _step unexpanded in the
            # sequence path and failed to broadcast against [B, dim]
            # unless B == dim by coincidence.
            return t.unsqueeze(1)
        if t.dim() == 3:
            return t.squeeze(-1)
        return t

    def forward(self, x, h_prev=None, t=None):
        """
        Args:
            x: [B, dim] single step, or [B, L, dim] sequence.
            h_prev: Optional previous hidden state [B, dim]; zeros if None.
            t: Optional time parameter, shape [B], [B, 1], or [B, 1, 1];
               sampled from time_scale if None.
        Returns:
            [B, dim] for single-step input, [B, L, dim] for sequences.
        """
        if x.dim() == 3:
            return self._forward_seq(x, h_prev, t)

        B, device = x.shape[0], x.device
        if h_prev is None:
            h_prev = torch.zeros(B, self.dim, device=device)
        t = self._sample_t((B, 1), device) if t is None else self._normalize_t(t)
        return self._step(x, h_prev, t)

    def _forward_seq(self, x, h_prev=None, t=None):
        """Unroll the cell over the L axis of x: [B, L, dim] -> [B, L, dim]."""
        B, L, D = x.shape
        device = x.device

        # One t per batch element, shared across all L steps.
        t = self._sample_t((B, 1), device) if t is None else self._normalize_t(t)

        outputs = []
        h = torch.zeros(B, D, device=device) if h_prev is None else h_prev
        for step in range(L):
            h = self._step(x[:, step, :], h, t)
            outputs.append(h)
        return torch.stack(outputs, dim=1)

    def _step(self, x, h_prev, t):
        """Core CfC step. x, h_prev: [B, dim]; t broadcastable to [B, dim]."""
        backbone_out = self.backbone(torch.cat([x, h_prev], dim=-1))
        f_base, g_base, h_base, skip = backbone_out.chunk(4, dim=-1)

        if self.conv is not None:
            # Residual depthwise conv; [B, dim] -> [B, dim, 1] -> [B, dim].
            # (Equivalent to the unsqueeze(1)/double-transpose formulation.)
            def refine(v):
                return v + self.conv(v.unsqueeze(-1)).squeeze(-1)

            f_base = refine(f_base)
            g_base = refine(g_base)
            h_base = refine(h_base)

        f_out = self.f_head(f_base)
        g_out = self.g_head(g_base)
        h_out = self.h_head(h_base)

        # Sigmoidal time gate interpolates between the g and h heads.
        gate = torch.sigmoid(-f_out * t)
        return self.out_proj(gate * g_out + (1 - gate) * h_out + skip)


class CfCBlock(nn.Module):
    """CfC block for 2D image processing with residual connections.

    Accepts either a [B, C, H, W] feature map (flattened to an H*W-long
    sequence) or an already-flat [B, L, C] sequence. A learned positional
    embedding is added before the CfC cell, and a pre-norm feed-forward
    sub-block follows.

    Args:
        dim: Channel / feature dimension C.
        dropout: Dropout probability for the CfC backbone and FF block.
        time_scale: Passed through to CfCCell.
        expansion_factor: FF hidden-size multiplier.
        max_seq_len: Capacity of the learned positional-embedding table
            (maximum supported H*W or L). Default 4096 matches the
            previously hard-coded size.
    """

    def __init__(self, dim, dropout=0.0, time_scale=(0.0, 1.0), expansion_factor=2,
                 max_seq_len=4096):
        super().__init__()
        self.dim = dim
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.cfc = CfCCell(dim=dim, backbone_dropout=dropout, time_scale=time_scale, use_conv=True)

        ff_dim = dim * expansion_factor
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, dim), nn.Dropout(dropout),
        )

        # Learned absolute positions, sliced to the actual sequence length.
        self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len, dim) * 0.02)

    def forward(self, x, return_2d=True):
        """
        Args:
            x: [B, C, H, W] feature map or [B, L, C] sequence.
            return_2d: For 4-D input, reshape the output back to
                [B, C, H, W]; ignored for 3-D input.
        Returns:
            Tensor in the input's layout (subject to return_2d).
        Raises:
            ValueError: If the sequence length exceeds max_seq_len.
        """
        is_2d = x.dim() == 4
        if is_2d:
            B, C, H, W = x.shape
            L = H * W
            x = x.flatten(2).transpose(1, 2)  # -> [B, L, C]
        else:
            B, L, C = x.shape

        # FIX: fail loudly instead of a cryptic broadcast error when the
        # sequence outgrows the positional-embedding table.
        if L > self.pos_embed.shape[1]:
            raise ValueError(
                f"sequence length {L} exceeds positional-embedding capacity "
                f"{self.pos_embed.shape[1]}"
            )

        x_with_pos = x + self.pos_embed[:, :L, :]
        residual = x
        h = self.cfc(self.norm1(x_with_pos))
        x_out = h + self.ff(self.norm2(h + residual))

        if is_2d and return_2d:
            x_out = x_out.transpose(1, 2).reshape(B, C, H, W)
        return x_out