Premchan369 committed
Commit 79a43db · verified · 1 Parent(s): ea80a93

Upload q_tensor_former.py

Files changed (1):
q_tensor_former.py (+493, -0)
q_tensor_former.py ADDED
@@ -0,0 +1,493 @@
#!/usr/bin/env python3
"""
Q-TensorFormer: Quantum-Enhanced Tensor Network LLM Compression Engine
=======================================================================
Hybrid quantum-tensor transformer with:
- Pure PyTorch Tensor-Train FFN layers (no compiled deps)
- PennyLane quantum angle encoding with TorchLayer
- Entanglement-guided adaptive rank scheduling
- Selective quantum routing (only "hard" tokens)
- Full benchmark against an identical-architecture baseline
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math, os
from typing import Optional, Tuple
from dataclasses import dataclass

import pennylane as qml

print("=" * 65)
print(" Q-TENSORFORMER: Quantum-Tensor LLM Compressor")
print("=" * 65)
print(f" PyTorch {torch.__version__} | PennyLane {qml.__version__}")
print()

# ═════════════════════════════════════════════════════════════════════
# CONFIG
# ═════════════════════════════════════════════════════════════════════

@dataclass
class CFG:
    d_model: int = 64
    n_heads: int = 4
    n_layers: int = 2
    ff_mult: int = 4
    max_seq: int = 64
    vocab: int = 1000
    tt_rank: int = 8
    min_rank: int = 2
    q_qubits: int = 4
    q_layers: int = 2
    q_sparsity: float = 0.3   # informational only; the router's gate threshold is fixed at 0.5
    dropout: float = 0.1
    lr: float = 3e-4
    rank_alpha: float = 2.0
    rank_smoothing: float = 0.9

# ═════════════════════════════════════════════════════════════════════
# 1. PURE PYTORCH TENSOR-TRAIN LINEAR LAYER
# ═════════════════════════════════════════════════════════════════════

def auto_factor(n, max_f=4):
    """Split n into at most max_f small factors whose product is n."""
    if n <= 1: return (1, 1)
    f, r = [], n
    for p in [2, 2, 2, 2, 2, 3, 3, 5, 7]:
        while r % p == 0 and len(f) < max_f:
            f.append(p); r //= p
    if r > 1:
        if len(f) < max_f: f.append(r)
        else: f[-1] *= r
    while len(f) < 2: f.insert(0, 1)
    return tuple(f[:max_f])
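
# Minimal sanity sketch (illustrative only; the `_demo_*` helpers in this
# file are arbitrary names and are never called by the benchmark):
# auto_factor splits a feature dimension into small factors whose product
# recovers it, which fixes the TT core shapes used below.
def _demo_auto_factor():
    # 64 -> (2, 2, 2, 8); 256 -> (2, 2, 2, 32); 1000 -> (2, 2, 2, 125)
    for n in (64, 256, 1000):
        assert math.prod(auto_factor(n)) == n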

class TTLinear(nn.Module):
    """Tensor-Train decomposed linear layer. Pure PyTorch, zero compiled deps."""
    def __init__(self, in_shape, out_shape, rank=8, bias=True):
        super().__init__()
        in_shape = tuple(in_shape)
        out_shape = tuple(out_shape)
        # Pad the shorter shape with 1s so both sides have the same TT order.
        max_d = max(len(in_shape), len(out_shape))
        in_shape = (1,) * (max_d - len(in_shape)) + in_shape
        out_shape = (1,) * (max_d - len(out_shape)) + out_shape
        assert len(in_shape) == len(out_shape)
        self.in_shape, self.out_shape = in_shape, out_shape
        self.rank, self.ndim = rank, len(in_shape)
        self.active_rank = rank  # current (possibly truncated) rank used by forward()
        self.in_feat = math.prod(in_shape)
        self.out_feat = math.prod(out_shape)
        self.cores = nn.ParameterList()
        for k in range(self.ndim):
            rl = 1 if k == 0 else rank
            rr = 1 if k == self.ndim - 1 else rank
            c = torch.empty(rl, out_shape[k], in_shape[k], rr)
            bnd = math.sqrt(6.0 / max(1, rl * in_shape[k] + rr * out_shape[k]))
            nn.init.uniform_(c, -bnd, bnd)
            # Wrap in nn.Parameter so the core is registered and trained;
            # appending a bare tensor would leave it out of .parameters().
            self.cores.append(nn.Parameter(c))
        self.bias = nn.Parameter(torch.zeros(self.out_feat)) if bias else None
        tp = sum(c.numel() for c in self.cores) + (self.bias.numel() if bias else 0)
        self.compr = (self.in_feat * self.out_feat) / max(tp, 1)

    def forward(self, x):
        bs = x.shape[:-1]
        B = math.prod(bs)
        x = x.reshape(B, self.in_feat)
        state = x.reshape(B, *self.in_shape)

        for k in range(self.ndim):
            core = self.cores[k]
            # Use only the first active_rank slices of each bond, so the
            # scheduler can shrink or grow the effective rank non-destructively.
            if k > 0: core = core[:self.active_rank]
            if k < self.ndim - 1: core = core[..., :self.active_rank]
            r_k, o_k, i_k, r_kp1 = core.shape

            if k == 0:
                rest = math.prod(self.in_shape[1:])
                s = state.reshape(B, i_k, rest)
                cm = core.squeeze(0).permute(1, 0, 2).reshape(i_k, o_k * r_kp1)
                s = torch.bmm(s.transpose(1, 2), cm.unsqueeze(0).expand(B, -1, -1))
                s = s.reshape(B, rest, o_k, r_kp1).permute(0, 3, 2, 1)
                state = s.reshape(B, r_kp1, -1)

            elif k == self.ndim - 1:
                prev_os = math.prod(self.out_shape[:k])
                s = state.reshape(B, r_k, prev_os, i_k)
                cm = core.squeeze(-1)
                s = torch.einsum('brpi,roi->bpo', s, cm)
                state = s.reshape(B, prev_os * o_k)

            else:
                prev_os = math.prod(self.out_shape[:k])
                rest_in = math.prod(self.in_shape[k+1:])
                s = state.reshape(B, r_k, prev_os * i_k * rest_in)
                s = s.reshape(B, r_k, prev_os, i_k, rest_in)
                s = torch.einsum('brpix,roiq->bpoqx', s, core)
                s = s.permute(0, 3, 1, 2, 4)
                state = s.reshape(B, r_kp1, prev_os * o_k * rest_in)

        out = state.reshape(B, self.out_feat)
        if self.bias is not None: out = out + self.bias
        return out.reshape(*bs, self.out_feat)

    def set_rank(self, nr):
        # Non-destructive truncation: clamp the rank used by forward()
        # instead of re-slicing the Parameters, which would detach them
        # from the optimizer and make rank reductions irreversible.
        self.active_rank = int(max(1, min(nr, self.rank)))
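
# Worked size example (illustrative sketch using the default CFG shapes):
# a dense 64 -> 256 linear layer stores 64 * 256 = 16,384 weights, while
# the TT version with rank 8 stores four cores of shapes (1,2,2,8),
# (8,2,2,8), (8,2,2,8) and (8,32,8,1), i.e. 2,592 core weights plus a
# 256-entry bias.
def _demo_tt_compression():
    tt = TTLinear(auto_factor(64), auto_factor(256), rank=8)
    dense = 64 * 256                              # 16,384
    tt_params = sum(c.numel() for c in tt.cores)  # 2,592
    return dense, tt_params, tt.compr             # compr is roughly 5.75x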

# ═════════════════════════════════════════════════════════════════════
# 2. QUANTUM ANGLE EMBEDDING (PennyLane)
# ═════════════════════════════════════════════════════════════════════

class QuantumEmbed(nn.Module):
    """Angle embedding → variational circuit → PauliZ expectations."""
    def __init__(self, n_q=4, layers=2, n_out=None):
        super().__init__()
        self.n_q, self.layers = n_q, layers
        n_out = n_out or n_q
        dev = qml.device("default.qubit", wires=n_q)

        @qml.qnode(dev, interface="torch", diff_method="backprop")
        def circ(inputs, w):
            # Angle-encode one feature per qubit, then apply `layers` rounds
            # of trainable RY rotations and a CNOT entangling ring.
            for i in range(n_q): qml.RX(inputs[..., i], wires=i)
            for L in range(layers):
                for i in range(n_q): qml.RY(w[L, i], wires=i)
                for i in range(n_q - 1): qml.CNOT(wires=[i, i+1])
                if n_q > 2: qml.CNOT(wires=[n_q-1, 0])
            return [qml.expval(qml.PauliZ(i)) for i in range(n_out)]

        self.qlayer = qml.qnn.TorchLayer(circ, {"w": (layers, n_q)})

    def forward(self, x): return self.qlayer(x)
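
# Minimal usage sketch (illustrative; the sizes here are assumptions, not
# benchmark settings): each input feature is angle-encoded on one qubit
# and the layer returns one PauliZ expectation per qubit, so outputs live
# in [-1, 1].
def _demo_quantum_embed():
    qe = QuantumEmbed(n_q=4, layers=2)
    angles = torch.rand(3, 4) * math.pi  # 3 samples, one angle per qubit
    z = qe(angles)                       # -> shape (3, 4)
    return z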

# ═════════════════════════════════════════════════════════════════════
# 3. TT FEED-FORWARD
# ═════════════════════════════════════════════════════════════════════

class TTFFN(nn.Module):
    def __init__(self, D, ff_mult=4, rank=8):
        super().__init__()
        E = D * ff_mult
        self.up = TTLinear(auto_factor(D), auto_factor(E), rank, True)
        self.down = TTLinear(auto_factor(E), auto_factor(D), rank, True)
    def forward(self, x): return self.down(F.gelu(self.up(x)))
    def set_rank(self, r): self.up.set_rank(r); self.down.set_rank(r)

# ═════════════════════════════════════════════════════════════════════
# 4. RANK SCHEDULER
# ═════════════════════════════════════════════════════════════════════

class RankScheduler(nn.Module):
    """rank = r_min + alpha * entropy (EMA-smoothed)"""
    def __init__(self, mn=2, mx=16, a=2.0, sm=0.9):
        super().__init__()
        self.mn, self.mx = mn, mx
        # Note: alpha is a Parameter, but the returned rank is rounded to
        # an int, so no gradient reaches it; effectively a tunable constant.
        self.alpha = nn.Parameter(torch.tensor(a))
        self.sm = sm
        self.register_buffer('ema', torch.tensor(0.5))
        self.register_buffer('cur', torch.tensor(float(mx)))
    def forward(self, ent):
        s = ent.mean().detach() if ent.numel() > 1 else ent.detach()
        self.ema = self.sm * self.ema + (1 - self.sm) * s
        raw = self.mn + self.alpha * self.ema
        r = int(torch.clamp(raw, self.mn, self.mx).round().item())
        if self.training: self.cur.fill_(r)
        return r
    @property
    def current(self): return int(self.cur.item())
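
# Worked example of the schedule (illustrative): with mn=2, alpha=2.0 and
# smoothing disabled, an attention entropy of 1.5 gives
# rank = clamp(2 + 2.0 * 1.5, 2, 16) = 5.
def _demo_rank_schedule():
    rs = RankScheduler(mn=2, mx=16, a=2.0, sm=0.0)  # sm=0: EMA disabled
    return rs(torch.tensor(1.5))                    # -> 5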

# ═════════════════════════════════════════════════════════════════════
# 5. QUANTUM ROUTER
# ═════════════════════════════════════════════════════════════════════

class QuantumRouter(nn.Module):
    """Learned gate: routes only hard tokens through the quantum circuit."""
    def __init__(self, D, qmod, thr=0.5):
        super().__init__()
        self.qmod = qmod
        self.thr = thr
        self.gate = nn.Sequential(
            nn.Linear(D, D // 4), nn.ReLU(), nn.Linear(D // 4, 1), nn.Sigmoid())
        self.register_buffer('tot', torch.tensor(0.0))
        self.register_buffer('qtok', torch.tensor(0.0))
    def forward(self, x):
        B, S, D = x.shape
        g = self.gate(x.reshape(-1, D)).squeeze(-1).reshape(B, S)
        m = (g > self.thr).float()
        if self.training:
            # Straight-through estimator: hard 0/1 mask in the forward
            # pass, gradient of the soft gate g in the backward pass.
            m = m.detach() + g - g.detach()
        xf = x.reshape(-1, D); mf = m.reshape(-1)
        sel = xf[mf > 0.5]; out = xf.clone()
        if sel.shape[0] > 0:
            qo = self.qmod(sel)
            if qo.shape[-1] != D:
                # Fallback projection, created lazily on first use. It is
                # built after the optimizer exists, so it stays untrained;
                # HybridBlock wires a trained output projection instead.
                if not hasattr(self, '_proj'):
                    self._proj = nn.Linear(qo.shape[-1], D).to(x.device)
                qo = self._proj(qo)
            out[mf > 0.5] = qo.to(out.dtype)
        # Detach before updating the statistics buffers so no autograd
        # history accumulates across steps.
        self.tot += B * S; self.qtok += m.detach().sum()
        return out.reshape(B, S, D), g
    def sparsity(self):
        if self.tot > 0: return 1.0 - (self.qtok / self.tot).item()
        return 1.0
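
# Routing sketch (illustrative): tokens whose learned gate value exceeds
# thr are sent through qmod, everything else passes through unchanged.
# An identity qmod makes the bookkeeping easy to see.
def _demo_router():
    router = QuantumRouter(D=8, qmod=nn.Identity())
    router.eval()                   # hard gate, no straight-through trick
    x = torch.randn(2, 5, 8)
    out, gate = router(x)           # out: (2, 5, 8), gate: (2, 5)
    return gate, router.sparsity()  # sparsity = fraction NOT routed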

# ═════════════════════════════════════════════════════════════════════
# 6. ATTENTION
# ═════════════════════════════════════════════════════════════════════

class MHA(nn.Module):
    def __init__(self, D, heads=4, drop=0.1):
        super().__init__()
        assert D % heads == 0
        self.h, self.hd = heads, D // heads
        self.scale = self.hd ** -0.5
        self.qkv = nn.Linear(D, 3 * D, bias=False)
        self.out = nn.Linear(D, D)
        self.drop = nn.Dropout(drop)
    def forward(self, x, mask=None):
        B, S, D = x.shape
        qkv = self.qkv(x).reshape(B, S, 3, self.h, self.hd).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        a = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            a = a.masked_fill(mask[:, None, None, :] == 0, float('-inf'))
        aw = F.softmax(a, dim=-1); aw = self.drop(aw)
        o = (aw @ v).transpose(1, 2).reshape(B, S, D)
        return self.out(o), aw

# ═════════════════════════════════════════════════════════════════════
# 7. HYBRID BLOCK
# ═════════════════════════════════════════════════════════════════════

class HybridBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        D = cfg.d_model
        self.a_norm = nn.LayerNorm(D)
        self.attn = MHA(D, cfg.n_heads, cfg.dropout)
        self.f_norm = nn.LayerNorm(D)
        self.ffn = TTFFN(D, cfg.ff_mult, cfg.tt_rank)
        self.qrouter = None
        if cfg.q_qubits:
            qc = QuantumEmbed(cfg.q_qubits, cfg.q_layers, cfg.q_qubits)
            # Project d_model -> qubit angles -> expectations -> d_model.
            # The output projection is registered here (not lazily inside
            # the router) so the optimizer sees its parameters.
            qw = nn.Sequential(nn.Linear(D, cfg.q_qubits), qc,
                               nn.Linear(cfg.q_qubits, D))
            self.qrouter = QuantumRouter(D, qw)
        self.rs = RankScheduler(cfg.min_rank, cfg.tt_rank, cfg.rank_alpha, cfg.rank_smoothing)
        self.drop = nn.Dropout(cfg.dropout)
    def forward(self, x, mask=None, adapt=True):
        ao, aw = self.attn(self.a_norm(x), mask)
        x = x + self.drop(ao)
        eps = 1e-8
        # Mean attention entropy drives the TT rank: diffuse (high-entropy)
        # attention gets more FFN capacity.
        ent = -torch.sum(aw * torch.log(aw + eps), dim=-1).mean(dim=-1).mean()
        tr = self.rs(ent) if adapt else self.rs.mx
        if adapt: self.ffn.set_rank(tr)
        n = self.f_norm(x)
        qs = 1.0
        if self.qrouter is not None:
            qo, _ = self.qrouter(n)
            # Value equals n + drop(qo); the "- n.detach() + n" term only
            # adds a gradient path through n for the non-routed tokens.
            n = n + self.drop(qo - n.detach() + n)
            qs = self.qrouter.sparsity()
        x = x + self.drop(self.ffn(n))
        return {'out': x, 'aw': aw, 'entropy': ent, 'rank': tr, 'qsparse': qs}

# ═════════════════════════════════════════════════════════════════════
# 8. Q-TENSORFORMER MODEL
# ═════════════════════════════════════════════════════════════════════

class QTensorFormer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.tok = nn.Embedding(cfg.vocab, cfg.d_model)
        self.pos = nn.Parameter(torch.randn(1, cfg.max_seq, cfg.d_model) * 0.02)
        self.layers = nn.ModuleList([HybridBlock(cfg) for _ in range(cfg.n_layers)])
        self.norm = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab, bias=False)
        self.head.weight = self.tok.weight  # weight tying with the embedding
        self._init()
    def _init(self):
        for p in self.parameters():
            if p.dim() >= 2: nn.init.xavier_uniform_(p)
    def forward(self, ids, mask=None, adapt=True):
        B, S = ids.shape
        x = self.tok(ids) + self.pos[:, :S, :]
        # Pass the raw (B, S) mask through unchanged; MHA broadcasts it
        # itself, so pre-expanding it here would double-expand the dims.
        bos = []
        for l in self.layers:
            o = l(x, mask, adapt); x = o['out']; bos.append(o)
        x = self.norm(x); logits = self.head(x)
        ent = torch.stack([b['entropy'] for b in bos]).mean()
        rk = sum(b['rank'] for b in bos) / len(bos)
        qs = sum(b['qsparse'] for b in bos) / len(bos)
        return {'logits': logits, 'entropy': ent, 'rank': rk, 'qsparse': qs}
    def loss(self, ids, mask=None, labels=None):
        if labels is None: labels = ids.clone()
        out = self(ids, mask)
        # Shift logits/labels by one for next-token prediction.
        sl = out['logits'][:, :-1].contiguous()
        ll = labels[:, 1:].contiguous()
        l = F.cross_entropy(sl.reshape(-1, self.cfg.vocab), ll.reshape(-1), ignore_index=-100)
        return {'loss': l, 'ppl': torch.exp(l), 'entropy': out['entropy'],
                'rank': out['rank'], 'qsparse': out['qsparse']}
    def nparams(self):
        t = sum(p.numel() for p in self.parameters())
        tr = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {'total': t, 'trainable': tr}
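
# End-to-end shape sketch (illustrative; not invoked by main() below):
def _demo_qtensorformer():
    cfg = CFG()
    model = QTensorFormer(cfg)
    ids = torch.randint(1, cfg.vocab, (2, 16))  # batch of 2, seq len 16
    out = model(ids)
    return out['logits'].shape                  # torch.Size([2, 16, 1000])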

# ═════════════════════════════════════════════════════════════════════
# 9. BASELINE (identical architecture, dense FFN)
# ═════════════════════════════════════════════════════════════════════

class Baseline(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.tok = nn.Embedding(cfg.vocab, cfg.d_model)
        self.pos = nn.Parameter(torch.randn(1, cfg.max_seq, cfg.d_model) * 0.02)
        self.drop = nn.Dropout(cfg.dropout)
        self.layers = nn.ModuleList()
        for _ in range(cfg.n_layers):
            self.layers.append(nn.ModuleDict({
                'a_n': nn.LayerNorm(cfg.d_model),
                'a': MHA(cfg.d_model, cfg.n_heads, cfg.dropout),
                'f_n': nn.LayerNorm(cfg.d_model),
                'ff': nn.Sequential(
                    nn.Linear(cfg.d_model, cfg.d_model * cfg.ff_mult),
                    nn.GELU(), nn.Dropout(cfg.dropout),
                    nn.Linear(cfg.d_model * cfg.ff_mult, cfg.d_model)),
            }))
        self.norm = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab, bias=False)
        self.head.weight = self.tok.weight  # weight tying with the embedding
        self._init()
    def _init(self):
        for p in self.parameters():
            if p.dim() >= 2: nn.init.xavier_uniform_(p)
    def forward(self, ids, mask=None):
        B, S = ids.shape
        x = self.tok(ids) + self.pos[:, :S, :]; x = self.drop(x)
        # MHA broadcasts the (B, S) mask itself; no pre-expansion here.
        for l in self.layers:
            ao, _ = l['a'](l['a_n'](x), mask); x = x + self.drop(ao)
            x = x + self.drop(l['ff'](l['f_n'](x)))
        return {'logits': self.head(self.norm(x))}
    def loss(self, ids, mask=None, labels=None):
        if labels is None: labels = ids.clone()
        out = self(ids, mask)
        sl = out['logits'][:, :-1].contiguous()
        ll = labels[:, 1:].contiguous()
        l = F.cross_entropy(sl.reshape(-1, self.cfg.vocab), ll.reshape(-1), ignore_index=-100)
        return {'loss': l, 'ppl': torch.exp(l)}
    def nparams(self):
        t = sum(p.numel() for p in self.parameters())
        tr = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {'total': t, 'trainable': tr}

# ═════════════════════════════════════════════════════════════════════
# 10. TRAINING UTILITIES
# ═════════════════════════════════════════════════════════════════════

def make_data(vocab=1000, seq=64, n=500, bs=16):
    """Synthetic corpus of uniformly random token ids (no structure to
    learn; used only to exercise the benchmark)."""
    d = torch.randint(1, vocab, (n, seq))
    ds = torch.utils.data.TensorDataset(d)
    return torch.utils.data.DataLoader(
        ds, batch_size=bs, shuffle=True,
        collate_fn=lambda batch: {'input_ids': torch.stack([item[0] for item in batch])})

def train_epoch(model, dl, opt, sched, e, tag="M"):
    model.train(); tl, tp, nb = 0.0, 0.0, 0; ex = {}
    for b in dl:
        ids = b['input_ids']; m = b.get('attention_mask')
        opt.zero_grad()
        out = model.loss(ids, m); out['loss'].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        if sched: sched.step()
        tl += out['loss'].item(); tp += out['ppl'].item(); nb += 1
        for k in ['entropy', 'rank', 'qsparse']:
            if k in out: ex[k] = ex.get(k, 0.0) + (out[k].item() if isinstance(out[k], torch.Tensor) else out[k])
    al, ap = tl / max(nb, 1), tp / max(nb, 1)
    s = f"[{tag}] E{e:2d} loss={al:.4f} ppl={ap:6.1f}"
    for k, v in ex.items(): s += f" {k}={v / max(nb, 1):.3f}"
    print(s); return al, ap

@torch.no_grad()
def evaluate(model, dl):
    model.eval(); tl, tp, nb = 0.0, 0.0, 0
    for b in dl:
        ids = b['input_ids']; m = b.get('attention_mask')
        out = model.loss(ids, m); tl += out['loss'].item(); tp += out['ppl'].item(); nb += 1
    return tl / max(nb, 1), tp / max(nb, 1)

# ═════════════════════════════════════════════════════════════════════
# 11. MAIN BENCHMARK
# ═════════════════════════════════════════════════════════════════════

def main():
    torch.manual_seed(42)
    cfg = CFG(d_model=64, n_layers=2, n_heads=4, tt_rank=8,
              q_qubits=4, q_sparsity=0.3, vocab=1000, max_seq=64)

    print(f"Config: d={cfg.d_model} layers={cfg.n_layers} heads={cfg.n_heads} rank={cfg.tt_rank}")
    print(f"Quantum: qubits={cfg.q_qubits} sparsity={cfg.q_sparsity}")
    print("Tensor FFN: ON\n")

    qt = QTensorFormer(cfg)
    bl = Baseline(cfg)

    pq = qt.nparams(); pb = bl.nparams()
    print(f"Q-TensorFormer params: {pq['trainable']:>10,}")
    print(f"Baseline params:       {pb['trainable']:>10,}")
    print(f"Compression ratio:     {pb['trainable'] / max(pq['trainable'], 1):>10.1f}x\n")

    train_dl = make_data(cfg.vocab, cfg.max_seq, 500, 16)
    val_dl = make_data(cfg.vocab, cfg.max_seq, 100, 16)
    E = 8

    print("=" * 50)
    print(" TRAINING Q-TENSORFORMER")
    print("=" * 50)
    oq = torch.optim.AdamW(qt.parameters(), lr=cfg.lr)
    sq = torch.optim.lr_scheduler.CosineAnnealingLR(oq, E * len(train_dl))
    for e in range(1, E + 1): train_epoch(qt, train_dl, oq, sq, e, "Q-TF")

    print("\n" + "=" * 50)
    print(" TRAINING BASELINE")
    print("=" * 50)
    ob = torch.optim.AdamW(bl.parameters(), lr=cfg.lr)
    sb = torch.optim.lr_scheduler.CosineAnnealingLR(ob, E * len(train_dl))
    for e in range(1, E + 1): train_epoch(bl, train_dl, ob, sb, e, "BSL")

    ql, qp = evaluate(qt, val_dl)
    bl_val, bp = evaluate(bl, val_dl)

    torch.save(qt.state_dict(), '/tmp/qt.pt')
    torch.save(bl.state_dict(), '/tmp/bl.pt')
    qsz = os.path.getsize('/tmp/qt.pt') / (1024 * 1024)
    bsz = os.path.getsize('/tmp/bl.pt') / (1024 * 1024)

    print("\n" + "=" * 65)
    print(" RESULTS")
    print("=" * 65)
    print(f"{'Metric':<30} {'Q-TensorFormer':>15} {'Baseline':>15}")
    print("-" * 60)
    print(f"{'Parameters':<30} {pq['trainable']:>13,} {pb['trainable']:>13,}")
    print(f"{'Val Loss':<30} {ql:>15.4f} {bl_val:>15.4f}")
    print(f"{'Val Perplexity':<30} {qp:>15.2f} {bp:>15.2f}")
    print(f"{'Model Size (MB)':<30} {qsz:>15.1f} {bsz:>15.1f}")

    ps = (1 - pq['trainable'] / pb['trainable']) * 100
    ss = (1 - qsz / bsz) * 100
    pr = qp / bp
    print(f"\nParameter reduction: {ps:.1f}%")
    print(f"Size reduction:      {ss:.1f}%")
    print(f"PPL ratio (Q-TF/BL): {pr:.2f}x")

    if pr < 1.1:
        print("\n >> VERDICT: Significant compression with minimal quality loss! <<")
    elif pr < 1.3:
        print("\n >> VERDICT: Moderate trade-off — compression worth the cost <<")
    else:
        print("\n >> VERDICT: Quality gap too large, needs tuning <<")

    print("\nDone!")
    return {'params_q': pq['trainable'], 'params_b': pb['trainable'],
            'qloss': ql, 'qppl': qp, 'bloss': bl_val, 'bppl': bp,
            'qsz': qsz, 'bsz': bsz, 'comp': ps, 'sred': ss, 'ppl_ratio': pr}

if __name__ == '__main__':
    results = main()