asdf98 committed
Commit 0b30f6b · verified · 1 Parent(s): a1e6c8b

Upload liqmamba/cfc.py

Files changed (1)
  1. liqmamba/cfc.py +164 -0
liqmamba/cfc.py ADDED
@@ -0,0 +1,164 @@
"""
Liquid CfC Cell — Closed-form Continuous-time Neural Network

Implements Theorem 1 from Hasani et al. (2021): an approximate closed-form
solution for Liquid Time-Constant (LTC) networks.

The full CfC cell from the paper computes:

    x(t) = (x(0) - A) * exp(-[w_tau + f(I, theta)] * t) * f(-I, theta) + A

CfCCell below implements the simplified variant that drops the f(-I, theta)
gating factor. The closed form gives O(1) computation per step (no ODE
solver needed) while preserving the expressive continuous-time dynamics
of LTCs.

In our architecture, CfC replaces static activation functions (SiLU/GELU)
with learnable temporal dynamics, giving each token adaptive computation
depth.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

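# Worked example of the closed-form update (illustrative numbers, not from
# the source): with x(0) = 1.0, A = 0.0 and an effective decay rate of 2.0,
#   t = 0.0  ->  x = (1 - 0) * exp(-2 * 0) + 0 = 1.0    (state unchanged)
#   t = 1.0  ->  x = (1 - 0) * exp(-2 * 1) + 0 ≈ 0.135  (relaxing toward A)
# As t grows, exp(-decay * t) -> 0 and the state settles at the target A.
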
class CfCCell(nn.Module):
    """
    Single CfC cell implementing the closed-form solution.

    Args:
        dim: Input/output dimension
        hidden_dim: Hidden dimension for the backbone network f
            (defaults to 2 * dim)
        time_constant_init: Initial value for the time constant w_tau
        bias_init: Initial value for the bias vector A
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int | None = None,
        time_constant_init: float = 1.0,
        bias_init: float = 0.0,
    ):
        super().__init__()
        hidden_dim = hidden_dim or dim * 2

        # Backbone network f(x, I, theta): maps input to activation
        self.backbone = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, dim),
        )

        # Time constant w_tau — controls how fast the state decays
        self.w_tau = nn.Parameter(torch.full((dim,), time_constant_init))

        # Bias vector A — steady-state target
        self.A = nn.Parameter(torch.full((dim,), bias_init))

        self.dim = dim

    def forward(self, x: torch.Tensor, t: float | torch.Tensor = 1.0) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape (..., dim)
            t: Time delta: a scalar, or a tensor matching the leading
               (batch, seq) dims of x (it is broadcast over the feature dim)
        Returns:
            Updated state of shape (..., dim)
        """
        # f(x, theta) — nonlinear transformation
        fx = self.backbone(x)

        # w_tau + f(x, theta): effective decay rate. Softplus keeps the
        # learned w_tau term positive (biological plausibility); fx itself
        # is unbounded.
        decay = F.softplus(self.w_tau) + fx

        # exp(-decay * t): temporal decay factor
        if isinstance(t, (int, float)):
            exp_term = torch.exp(-decay * t)
        else:
            exp_term = torch.exp(-decay * t.unsqueeze(-1))

        # Closed-form solution: x(t) = (x(0) - A) * exp(-decay * t) + A,
        # with the input x playing the role of the initial state x(0)
        state = (x - self.A) * exp_term + self.A

        return state

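# Usage sketch for CfCCell (shapes here are illustrative assumptions):
#
#   cell = CfCCell(dim=64)
#   x = torch.randn(2, 16, 64)   # (batch, seq, dim)
#   y = cell(x, t=0.5)           # scalar time delta -> same shape as x
#   dt = torch.rand(2, 16)       # per-token time deltas over (batch, seq)
#   y = cell(x, t=dt)
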
class CfCLayer(nn.Module):
    """
    Multi-neuron CfC layer. Each neuron has its own CfC dynamics (via the
    cell's per-dimension time constants), and neurons interact through a
    linear projection applied before the CfC update.

    This follows the CfC model from the paper, where:
    - Each neuron has independent time constants
    - Neurons interact through a learned weight matrix
    - The closed-form solution gives O(1) per-step computation
    """

    def __init__(
        self,
        dim: int,
        expansion_factor: int = 2,
        use_residual: bool = True,
        dropout: float = 0.0,
    ):
        super().__init__()

        # Input projection (mixes information between neurons)
        self.input_proj = nn.Linear(dim, dim * expansion_factor)

        # CfC cell operating on the expanded dimension
        self.cfc = CfCCell(dim * expansion_factor)

        # Output projection back to the original dim
        self.output_proj = nn.Linear(dim * expansion_factor, dim)

        self.use_residual = use_residual
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, t: float = 1.0) -> torch.Tensor:
        residual = x

        # Expand → CfC → Contract
        h = self.input_proj(x)
        h = torch.tanh(h)  # Activation before CfC (as in the paper)
        h = self.cfc(h, t)
        h = self.output_proj(h)
        h = self.dropout(h)

        if self.use_residual:
            h = residual + h

        return self.norm(h)

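# Usage sketch for CfCLayer as a drop-in block (sizes are illustrative):
#
#   layer = CfCLayer(dim=64, expansion_factor=2, dropout=0.1)
#   x = torch.randn(2, 16, 64)
#   y = layer(x)                 # (2, 16, 64), residual + LayerNorm applied
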
class CfCGate(nn.Module):
    """
    CfC-inspired gating mechanism for Mamba blocks.

    Instead of the standard SiLU gating used in Mamba, we use a CfC cell
    that computes an adaptive, time-dependent gate value. This gives the
    model:
    1. Per-token computation depth (as in adaptive computation time)
    2. Smooth temporal dynamics that help prevent gradient explosion
    3. Better handling of long-range dependencies through the liquid state

    The gate is computed as:

        gate = sigmoid(linear(CfC(linear(x))))

    where the CfC cell gives a smooth output and the sigmoid bounds it
    for gating.
    """

    def __init__(self, dim: int, hidden_dim: int | None = None):
        super().__init__()
        hidden_dim = hidden_dim or dim
        self.proj = nn.Linear(dim, hidden_dim)
        self.cfc = CfCCell(hidden_dim)
        self.gate_proj = nn.Linear(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.proj(x)
        h = self.cfc(h)
        gate = torch.sigmoid(self.gate_proj(h))
        return gate
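

# A minimal smoke test (a sketch added for illustration; the sizes below
# are arbitrary assumptions, not values from the source):
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 16, 64)  # (batch, seq, dim)

    cell = CfCCell(dim=64)
    layer = CfCLayer(dim=64)
    gate = CfCGate(dim=64)

    # All three modules preserve the input shape
    assert cell(x, t=0.5).shape == x.shape
    assert layer(x).shape == x.shape
    assert gate(x).shape == x.shape

    # Gate values are bounded to (0, 1) by the sigmoid
    g = gate(x)
    assert (g.min() >= 0.0) and (g.max() <= 1.0)
    print("CfC modules: shape and range checks passed")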