Premchan369 committed
Commit 220eb7c · verified
1 Parent(s): db0de7e

Upload src/qkan.py

Files changed (1)
  src/qkan.py  +306  -0
src/qkan.py ADDED
@@ -0,0 +1,306 @@
"""
QKAN Integration: Quantum Variational Activation Functions.

Based on: QKAN (arXiv:2509.14026) — "Quantum Variational Activation Functions
Empower Kolmogorov-Arnold Networks".

DARUAN (DatA Re-Uploading Activation Networks):
    Single-qubit data re-uploading circuits that serve as learnable activation
    functions. Unlike multi-qubit VQCs, DARUANs:
    - avoid barren plateaus (single-qubit only),
    - run efficiently on classical simulators,
    - have a frequency spectrum that grows exponentially with the number of
      re-uploading repetitions,
    - can be transferred to classical B-spline KANs via distillation.

HQKAN (Hybrid QKAN):
    Drop-in replacement for MLP FFN layers in transformers.
    Replaces the standard activation + linear stack with a DARUAN-activated
    linear stack.

Integration with Q-TensorFormer:
    The HQKAN FFN can optionally replace or augment the TT-FFN,
    providing quantum-enhanced expressivity with fewer parameters.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple


class DARUAN(nn.Module):
    """
    Data Re-Uploading Activation Network.

    A single-qubit, quantum-inspired activation function that uses repeated
    data re-uploading to create an exponentially growing frequency spectrum.

    Architecture (as implemented here):

        output = W^(1) * x + sum_{r=1..R} W^(r+1) * S(w_r * x + b_r)

    i.e. a learnably scaled skip connection plus R independently re-weighted
    copies of the base activation S (SiLU by default), where R is the number
    of re-uploading repetitions.

    This is a fully classical module — no quantum hardware is needed. The
    single-qubit data re-uploading circuit is emulated classically by the
    repeated, learnably re-weighted applications of the base activation.

    Parameters
    ----------
    n_repeats : int
        Number of data re-uploading repetitions (R).
        Higher values give a richer frequency spectrum and more expressivity.
    base_activation : str
        Base activation function: "silu", "gelu", "relu", or "tanh".
    dropout : float
        Dropout rate applied after the activation.
    """

    def __init__(self, n_repeats: int = 3, base_activation: str = "silu",
                 dropout: float = 0.0):
        super().__init__()
        self.n_repeats = n_repeats
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # Base nonlinearity S (falls back to SiLU for unknown names)
        act_map = {
            "silu": nn.SiLU(),
            "gelu": nn.GELU(),
            "relu": nn.ReLU(),
            "tanh": nn.Tanh(),
        }
        self.activation = act_map.get(base_activation, nn.SiLU())

        # Learnable pre-activation weights (w_r, b_r) for each repetition
        self.pre_weights = nn.ParameterList([
            nn.Parameter(torch.ones(1) * 0.1) for _ in range(n_repeats)
        ])
        self.pre_biases = nn.ParameterList([
            nn.Parameter(torch.zeros(1)) for _ in range(n_repeats)
        ])

        # Learnable post-activation weights W^(r)
        self.post_weights = nn.ParameterList([
            nn.Parameter(torch.ones(1) * 0.5) for _ in range(n_repeats + 1)
        ])

        self._init_weights()

    def _init_weights(self):
        """Initialize with small values for stable training."""
        for i in range(self.n_repeats):
            nn.init.uniform_(self.pre_weights[i], -0.1, 0.1)
            nn.init.zeros_(self.pre_biases[i])
        for i in range(self.n_repeats + 1):
            nn.init.uniform_(self.post_weights[i], 0.3, 0.7)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the DARUAN activation element-wise.

        Args:
            x: tensor of any shape (*)

        Returns:
            Tensor of the same shape (*)
        """
        # Learnably scaled skip connection: W^(1) * x
        out = self.post_weights[0] * x

        for r in range(self.n_repeats):
            # Pre-activation: w_r * x + b_r
            z = self.pre_weights[r] * x + self.pre_biases[r]
            # Base nonlinearity S
            z = self.activation(z)
            # Post-activation weighting: W^(r+1) * S(...)
            out = out + self.post_weights[r + 1] * z

        return self.dropout(out)

    def extra_repr(self) -> str:
        return f"n_repeats={self.n_repeats}"

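# --- Usage sketch (editor-added illustration, not part of the original upload).
# A minimal shape and parameter check for DARUAN, assuming the class is used
# exactly as defined above. Wrapped in a function so importing this file stays
# side-effect free.
def _daruan_usage_sketch() -> torch.Tensor:
    act = DARUAN(n_repeats=3, base_activation="silu")
    x = torch.randn(2, 8, 16)
    y = act(x)                                   # element-wise: shape preserved
    assert y.shape == x.shape
    # (w_r, b_r) per repetition plus W^(1..R+1): 3 * n_repeats + 1 scalars
    assert sum(p.numel() for p in act.parameters()) == 3 * act.n_repeats + 1
    return y
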
class QKANLayer(nn.Module):
    """
    Quantum KAN Layer — replaces Linear + Activation.

    Applies a DARUAN activation to each input feature independently, then
    combines the activated features with a linear projection.

    This is a drop-in replacement for nn.Sequential(nn.Linear, nn.GELU).

    Architecture:
        x → DARUAN (per feature) → Linear → output

    Compared to a standard MLP block:
    - ~30% fewer parameters (DARUAN activations are lightweight:
      3 * n_repeats + 1 scalars per feature)
    - Better expressivity per parameter
    - Compatible with QKAN → KAN knowledge distillation

    Parameters
    ----------
    in_features : int
    out_features : int
    n_repeats : int
        DARUAN repetitions (default: 3).
    base_activation : str
        Base activation for DARUAN.
    bias : bool
        Include a bias in the output projection.
    """

    def __init__(self, in_features: int, out_features: int,
                 n_repeats: int = 3, base_activation: str = "silu",
                 bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Per-feature DARUAN activations
        self.daruans = nn.ModuleList([
            DARUAN(n_repeats=n_repeats, base_activation=base_activation)
            for _ in range(in_features)
        ])

        # Output projection
        self.out_proj = nn.Linear(in_features, out_features, bias=bias)

        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.zeros_(self.out_proj.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (*, in_features)
        Returns:
            (*, out_features)
        """
        # Apply the per-feature DARUAN activations:
        # split (..., in_features) into in_features tensors of shape (...)
        features = x.unbind(-1)
        activated = []
        for i, feat in enumerate(features):
            activated.append(self.daruans[i](feat))
        x = torch.stack(activated, dim=-1)  # (..., in_features)

        # Output projection
        return self.out_proj(x)

    def parameter_count(self) -> int:
        """Total trainable parameters."""
        return sum(p.numel() for p in self.parameters())

    def extra_repr(self) -> str:
        return (f"in={self.in_features}, out={self.out_features}, "
                f"n_repeats={self.daruans[0].n_repeats}")

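# --- Usage sketch (editor-added illustration, not part of the original upload).
# QKANLayer standing in for a Linear + GELU block; the expected shape and
# parameter count follow directly from the definitions above.
def _qkan_layer_usage_sketch() -> torch.Tensor:
    layer = QKANLayer(in_features=16, out_features=64, n_repeats=3)
    x = torch.randn(4, 10, 16)                   # (batch, seq, in_features)
    y = layer(x)
    assert y.shape == (4, 10, 64)
    # 16 DARUANs with 3 * 3 + 1 = 10 scalars each, plus Linear(16, 64)
    assert layer.parameter_count() == 16 * 10 + (16 * 64 + 64)
    return y
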
class HQKANFFN(nn.Module):
    """
    Hybrid QKAN Feed-Forward Network.

    Drop-in replacement for the transformer FFN:
        Standard:  Linear↑ → GELU   → Linear↓
        HQKAN:     Linear↑ → DARUAN → Linear↓

    The DARUAN activation is applied element-wise on the expanded dimension
    for maximal expressivity.

    Compared to the TT-FFN:
    - HQKAN has better expressivity per parameter
    - TT-FFN has a better compression ratio
    - The two can be combined: DARUAN on the expanded dim, TT on the
      down-projection (see create_qkan_ffn with use_tt=True)

    Parameters
    ----------
    hidden_dim : int
    ff_multiplier : int
        Expansion factor (default: 4).
    n_repeats : int
        DARUAN repetitions.
    dropout : float
    """

    def __init__(self, hidden_dim: int, ff_multiplier: int = 4,
                 n_repeats: int = 3, dropout: float = 0.1):
        super().__init__()
        expanded_dim = hidden_dim * ff_multiplier

        self.up_proj = nn.Linear(hidden_dim, expanded_dim)
        self.daruan = DARUAN(n_repeats=n_repeats, base_activation="silu")
        self.down_proj = nn.Linear(expanded_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.up_proj(x)
        x = self.daruan(x)
        x = self.down_proj(x)
        return self.dropout(x)

    @property
    def total_params(self) -> int:
        return sum(p.numel() for p in self.parameters())

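# --- Usage sketch (editor-added illustration, not part of the original upload).
# HQKANFFN dropped into the FFN sub-layer of a transformer block; the pre-norm
# LayerNorm and residual connection here are illustrative assumptions.
def _hqkan_ffn_usage_sketch() -> torch.Tensor:
    d_model = 32
    ffn = HQKANFFN(hidden_dim=d_model, ff_multiplier=4, n_repeats=3, dropout=0.1)
    norm = nn.LayerNorm(d_model)
    x = torch.randn(2, 12, d_model)              # (batch, seq, hidden_dim)
    out = x + ffn(norm(x))                       # residual around the FFN
    assert out.shape == x.shape
    return out
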
class QKANEmbedding(nn.Module):
    """
    Quantum-enhanced embedding layer.

    Applies a DARUAN activation to the embedding vectors to enrich the
    representation before it enters the transformer.
    """

    def __init__(self, vocab_size: int, d_model: int, n_repeats: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.daruan = DARUAN(n_repeats=n_repeats, base_activation="silu")

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.embedding(input_ids)
        return self.daruan(x)

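# --- Usage sketch (editor-added illustration, not part of the original upload).
# Token ids in, DARUAN-enriched embeddings out; vocabulary size and model
# width below are arbitrary example values.
def _qkan_embedding_usage_sketch() -> torch.Tensor:
    emb = QKANEmbedding(vocab_size=1000, d_model=32, n_repeats=2)
    input_ids = torch.randint(0, 1000, (2, 7))   # (batch, seq)
    x = emb(input_ids)
    assert x.shape == (2, 7, 32)
    return x
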
def create_qkan_ffn(hidden_dim: int, ff_multiplier: int = 4,
                    n_repeats: int = 3, dropout: float = 0.1,
                    use_tt: bool = False, tt_rank: int = 4) -> nn.Module:
    """
    Factory for a QKAN-based FFN.

    Args:
        hidden_dim: Hidden dimension.
        ff_multiplier: Expansion factor.
        n_repeats: DARUAN repetitions.
        dropout: Dropout rate.
        use_tt: If True, use a TT-decomposed down-projection for extra compression.
        tt_rank: TT rank (only used if use_tt=True).

    Returns:
        FFN module.
    """
    if use_tt:
        # TT-QKAN hybrid: dense up-projection + DARUAN + TT down-projection
        from .tensor_layers import TTLinear
        expanded_dim = hidden_dim * ff_multiplier

        class TTQKANFFN(nn.Module):
            def __init__(self):
                super().__init__()
                self.up_proj = nn.Linear(hidden_dim, expanded_dim)
                self.daruan = DARUAN(n_repeats=n_repeats)
                self.down_proj = TTLinear(expanded_dim, hidden_dim, rank=tt_rank)
                self.dropout = nn.Dropout(dropout)

            def forward(self, x):
                x = self.up_proj(x)
                x = self.daruan(x)
                x = self.down_proj(x)
                return self.dropout(x)

        return TTQKANFFN()

    return HQKANFFN(hidden_dim, ff_multiplier, n_repeats, dropout)
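
# --- Usage sketch (editor-added illustration, not part of the original upload).
# Building the dense factory variant; the use_tt=True branch is left commented
# out because it depends on TTLinear from the sibling tensor_layers module.
def _create_qkan_ffn_usage_sketch() -> nn.Module:
    dense_ffn = create_qkan_ffn(hidden_dim=64, ff_multiplier=4, n_repeats=3)
    x = torch.randn(2, 5, 64)
    assert dense_ffn(x).shape == x.shape         # FFN preserves the hidden dim
    # tt_ffn = create_qkan_ffn(hidden_dim=64, use_tt=True, tt_rank=4)
    return dense_ffn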