Upload q_tensor_former.py

q_tensor_former.py (added, +493 lines)
#!/usr/bin/env python3
"""
Q-TensorFormer: Quantum-Enhanced Tensor Network LLM Compression Engine
=======================================================================
Hybrid quantum-tensor transformer with:
- Pure PyTorch Tensor-Train FFN layers (no compiled deps)
- PennyLane quantum angle encoding with TorchLayer
- Entanglement-guided adaptive rank scheduling
- Selective quantum routing (only "hard" tokens)
- Full benchmark against identical-architecture baseline
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math, os
from typing import Optional, Tuple
from dataclasses import dataclass

import pennylane as qml

print("=" * 65)
print(" Q-TENSORFORMER: Quantum-Tensor LLM Compressor")
print("=" * 65)
print(f" PyTorch {torch.__version__} | PennyLane {qml.__version__}")
print()

# ═════════════════════════════════════════════════════════════════════
# CONFIG
# ═════════════════════════════════════════════════════════════════════

@dataclass
class CFG:
    d_model: int = 64
    n_heads: int = 4
    n_layers: int = 2
    ff_mult: int = 4
    max_seq: int = 64
    vocab: int = 1000
    tt_rank: int = 8
    min_rank: int = 2
    q_qubits: int = 4
    q_layers: int = 2
    q_sparsity: float = 0.3
    dropout: float = 0.1
    lr: float = 3e-4
    rank_alpha: float = 2.0
    rank_smoothing: float = 0.9

# ═════════════════════════════════════════════════════════════════════
# 1. PURE PYTORCH TENSOR-TRAIN LINEAR LAYER
# ═════════════════════════════════════════════════════════════════════

def auto_factor(n, max_f=4):
    if n <= 1: return (1, 1)
    f, r = [], n
    for p in [2, 2, 2, 2, 2, 3, 3, 5, 7]:
        while r % p == 0 and len(f) < max_f:
            f.append(p); r //= p
    if r > 1:
        if len(f) < max_f: f.append(r)
        else: f[-1] *= r
    while len(f) < 2: f.insert(0, 1)
    return tuple(f[:max_f])
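
# Worked example (for orientation, not executed): with the defaults above,
# auto_factor(64) -> (2, 2, 2, 8) and auto_factor(256) -> (2, 2, 2, 32), so the
# 64 -> 256 FFN up-projection below is factored into four TT cores.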

class TTLinear(nn.Module):
    """Tensor-Train decomposed linear layer. Pure PyTorch, zero compiled deps."""
    def __init__(self, in_shape, out_shape, rank=8, bias=True):
        super().__init__()
        in_shape = tuple(in_shape)
        out_shape = tuple(out_shape)
        max_d = max(len(in_shape), len(out_shape))
        in_shape = (1,) * (max_d - len(in_shape)) + in_shape
        out_shape = (1,) * (max_d - len(out_shape)) + out_shape
        assert len(in_shape) == len(out_shape)
        self.in_shape, self.out_shape = in_shape, out_shape
        self.rank, self.ndim = rank, len(in_shape)
        self.in_feat = math.prod(in_shape)
        self.out_feat = math.prod(out_shape)
        self.cores = nn.ParameterList()
        for k in range(self.ndim):
            # core k has shape (r_{k-1}, out_k, in_k, r_k), boundary ranks fixed to 1
            rl = 1 if k == 0 else rank
            rr = 1 if k == self.ndim - 1 else rank
            c = torch.empty(rl, out_shape[k], in_shape[k], rr)
            bnd = math.sqrt(6.0 / max(1, rl*in_shape[k] + rr*out_shape[k]))
            nn.init.uniform_(c, -bnd, bnd)
            # wrap in nn.Parameter so the cores are registered and trained
            self.cores.append(nn.Parameter(c))
        self.bias = nn.Parameter(torch.zeros(self.out_feat)) if bias else None
        tp = sum(c.numel() for c in self.cores) + (self.bias.numel() if bias else 0)
        self.compr = (self.in_feat * self.out_feat) / max(tp, 1)

    def forward(self, x):
        bs = x.shape[:-1]
        B = math.prod(bs)
        x = x.reshape(B, self.in_feat)
        state = x.reshape(B, *self.in_shape)

        # contract the TT cores left to right, carrying a (B, rank, ...) state
        for k in range(self.ndim):
            core = self.cores[k]
            r_k, o_k, i_k, r_kp1 = core.shape

            if k == 0:
                rest = math.prod(self.in_shape[1:])
                s = state.reshape(B, i_k, rest)
                cm = core.squeeze(0).permute(1, 0, 2).reshape(i_k, o_k * r_kp1)
                s = torch.bmm(s.transpose(1, 2), cm.unsqueeze(0).expand(B, -1, -1))
                s = s.reshape(B, rest, o_k, r_kp1).permute(0, 3, 2, 1)
                state = s.reshape(B, r_kp1, -1)

            elif k == self.ndim - 1:
                prev_os = math.prod(self.out_shape[:k])
                s = state.reshape(B, r_k, prev_os, i_k)
                cm = core.squeeze(-1)
                s = torch.einsum('brpi,roi->bpo', s, cm)
                state = s.reshape(B, prev_os * o_k)

            else:
                prev_os = math.prod(self.out_shape[:k])
                rest_in = math.prod(self.in_shape[k+1:])
                s = state.reshape(B, r_k, prev_os * i_k * rest_in)
                s = s.reshape(B, r_k, prev_os, i_k, rest_in)
                s = torch.einsum('brpix,roiq->bpoqx', s, core)
                s = s.permute(0, 3, 1, 2, 4)
                state = s.reshape(B, r_kp1, prev_os * o_k * rest_in)

        out = state.reshape(B, self.out_feat)
        if self.bias is not None: out = out + self.bias
        return out.reshape(*bs, self.out_feat)

    def set_rank(self, nr):
        # NOTE: this truncation is one-way (ranks can shrink but never grow back)
        # and replaces the Parameter objects, so an optimizer built beforehand
        # keeps references to the old cores rather than the truncated ones.
        for i, c in enumerate(self.cores):
            s = [slice(None)]*4
            if i > 0: s[0] = slice(None, nr)
            if i < self.ndim - 1: s[3] = slice(None, nr)
            self.cores[i] = nn.Parameter(c[tuple(s)].clone())
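
# Rough parameter count at the default config (illustrative): the up-projection
# maps 64 -> 256 features, i.e. in_shape (2, 2, 2, 8), out_shape (2, 2, 2, 32),
# rank 8. Core sizes are 1*2*2*8 + 8*2*2*8 + 8*2*2*8 + 8*32*8*1 = 2,592 weights
# (+256 bias) versus 16,384 (+256) for a dense nn.Linear, roughly a 5-6x
# reduction for that single layer.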

# ═════════════════════════════════════════════════════════════════════
# 2. QUANTUM ANGLE EMBEDDING (PennyLane)
# ═════════════════════════════════════════════════════════════════════

class QuantumEmbed(nn.Module):
    """Angle embedding → variational circuit → PauliZ expectations."""
    def __init__(self, n_q=4, layers=2, n_out=None):
        super().__init__()
        self.n_q, self.layers = n_q, layers
        n_out = n_out or n_q
        dev = qml.device("default.qubit", wires=n_q)

        @qml.qnode(dev, interface="torch", diff_method="backprop")
        def circ(inputs, w):
            for i in range(n_q): qml.RX(inputs[..., i], wires=i)
            for L in range(layers):
                for i in range(n_q): qml.RY(w[L, i], wires=i)
                for i in range(n_q-1): qml.CNOT(wires=[i, i+1])
                if n_q > 2: qml.CNOT(wires=[n_q-1, 0])
            return [qml.expval(qml.PauliZ(i)) for i in range(n_out)]

        self.qlayer = qml.qnn.TorchLayer(circ, {"w": (layers, n_q)})

    def forward(self, x): return self.qlayer(x)
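
# Shape sketch (illustrative, not executed): QuantumEmbed(4, 2)(torch.rand(8, 4))
# returns an (8, 4) tensor of PauliZ expectation values in [-1, 1]; TorchLayer
# broadcasts (or iterates) the circuit over the leading batch dimension.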

# ═════════════════════════════════════════════════════════════════════
# 3. TT FEED-FORWARD
# ═════════════════════════════════════════════════════════════════════

class TTFFN(nn.Module):
    def __init__(self, D, ff_mult=4, rank=8):
        super().__init__()
        E = D * ff_mult
        self.up = TTLinear(auto_factor(D), auto_factor(E), rank, True)
        self.down = TTLinear(auto_factor(E), auto_factor(D), rank, True)
    def forward(self, x): return self.down(F.gelu(self.up(x)))
    def set_rank(self, r): self.up.set_rank(r); self.down.set_rank(r)

# ═════════════════════════════════════════════════════════════════════
# 4. RANK SCHEDULER
# ═════════════════════════════════════════════════════════════════════

class RankScheduler(nn.Module):
    """rank = r_min + alpha * entropy (EMA-smoothed)"""
    def __init__(self, mn=2, mx=16, a=2.0, sm=0.9):
        super().__init__()
        self.mn, self.mx = mn, mx
        # alpha is declared as a Parameter, but the schedule below detaches the
        # entropy and converts the result with .item(), so no gradient reaches it
        self.alpha = nn.Parameter(torch.tensor(a))
        self.sm = sm
        self.register_buffer('ema', torch.tensor(0.5))
        self.register_buffer('cur', torch.tensor(float(mx)))
    def forward(self, ent):
        s = ent.mean().detach() if ent.numel() > 1 else ent.detach()
        self.ema = self.sm*self.ema + (1-self.sm)*s
        raw = self.mn + self.alpha*self.ema
        r = int(torch.clamp(raw, self.mn, self.mx).round().item())
        if self.training: self.cur.fill_(r)
        return r
    @property
    def current(self): return int(self.cur.item())
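
# Numeric example (illustrative): with mn=2, alpha=2.0 and an attention-entropy
# EMA of about 3.0 nats, raw = 2 + 2.0*3.0 = 8, so the scheduled TT rank is 8
# (after rounding and clamping to [mn, mx]).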

# ═════════════════════════════════════════════════════════════════════
# 5. QUANTUM ROUTER
# ═════════════════════════════════════════════════════════════════════

class QuantumRouter(nn.Module):
    """Learned gate: routes only hard tokens through quantum circuit."""
    def __init__(self, D, qmod, thr=0.5):
        super().__init__()
        self.qmod = qmod
        self.thr = thr
        self.gate = nn.Sequential(
            nn.Linear(D, D//4), nn.ReLU(), nn.Linear(D//4, 1), nn.Sigmoid())
        self.register_buffer('tot', torch.tensor(0.0))
        self.register_buffer('qtok', torch.tensor(0.0))
    def forward(self, x):
        B, S, D = x.shape
        g = self.gate(x.reshape(-1, D)).squeeze(-1).reshape(B, S)
        m = (g > self.thr).float()
        if self.training:
            # straight-through estimator; the output below only uses the hard mask,
            # so the gate is trained only through whatever consumes the returned `g`
            m = m.detach() + g - g.detach()
        xf = x.reshape(-1, D); mf = m.reshape(-1)
        sel = xf[mf > 0.5]; out = xf.clone()
        if sel.shape[0] > 0:
            qo = self.qmod(sel)
            if qo.shape[-1] != D:
                # fallback projection; created lazily, so it is invisible to any
                # optimizer that was built before the first forward pass
                if not hasattr(self, '_proj'):
                    self._proj = nn.Linear(qo.shape[-1], D).to(x.device)
                qo = self._proj(qo)
            out[mf > 0.5] = qo.to(out.dtype)
        self.tot += B*S; self.qtok += m.sum()
        return out.reshape(B, S, D), g
    def sparsity(self):
        if self.tot > 0: return 1.0 - (self.qtok/self.tot).item()
        return 1.0
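
# Routing note (observation): with a freshly initialized gate the pre-sigmoid
# outputs sit near zero, so roughly half of the tokens clear the 0.5 threshold at
# the start of training; CFG.q_sparsity is printed in main() but is not currently
# read by the router.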

# ═════════════════════════════════════════════════════════════════════
# 6. ATTENTION
# ═════════════════════════════════════════════════════════════════════

class MHA(nn.Module):
    def __init__(self, D, heads=4, drop=0.1):
        super().__init__()
        assert D % heads == 0
        self.h, self.hd = heads, D//heads
        self.scale = self.hd**-0.5
        self.qkv = nn.Linear(D, 3*D, bias=False)
        self.out = nn.Linear(D, D)
        self.drop = nn.Dropout(drop)
    def forward(self, x, mask=None):
        B, S, D = x.shape
        qkv = self.qkv(x).reshape(B, S, 3, self.h, self.hd).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        a = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            # callers pass a mask already expanded to (B, 1, 1, S); expanding it
            # again here would add spurious dimensions, so broadcast it directly
            a = a.masked_fill(mask == 0, float('-inf'))
        aw = F.softmax(a, dim=-1); aw = self.drop(aw)
        o = (aw @ v).transpose(1, 2).reshape(B, S, D)
        return self.out(o), aw

# ═════════════════════════════════════════════════════════════════════
# 7. HYBRID BLOCK
# ═════════════════════════════════════════════════════════════════════

class HybridBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        D = cfg.d_model
        self.a_norm = nn.LayerNorm(D)
        self.attn = MHA(D, cfg.n_heads, cfg.dropout)
        self.f_norm = nn.LayerNorm(D)
        self.ffn = TTFFN(D, cfg.ff_mult, cfg.tt_rank)
        self.qrouter = None
        if cfg.q_qubits:
            qc = QuantumEmbed(cfg.q_qubits, cfg.q_layers, cfg.q_qubits)
            # project d_model -> qubits before the circuit and back to d_model after,
            # so all projection weights exist before the optimizer is constructed
            qw = nn.Sequential(nn.Linear(D, cfg.q_qubits), qc, nn.Linear(cfg.q_qubits, D))
            self.qrouter = QuantumRouter(D, qw)
        self.rs = RankScheduler(cfg.min_rank, cfg.tt_rank, cfg.rank_alpha, cfg.rank_smoothing)
        self.drop = nn.Dropout(cfg.dropout)
    def forward(self, x, mask=None, adapt=True):
        ao, aw = self.attn(self.a_norm(x), mask)
        x = x + self.drop(ao)
        eps = 1e-8
        ent = -torch.sum(aw*torch.log(aw+eps), dim=-1).mean(dim=-1).mean()
        tr = self.rs(ent) if adapt else self.rs.mx
        if adapt: self.ffn.set_rank(tr)  # destructive truncation; see TTLinear.set_rank note
        n = self.f_norm(x)
        qs = 1.0
        if self.qrouter is not None:
            qo, _ = self.qrouter(n)
            # value-wise this is n + dropout(qo); the extra -n.detach() + n term
            # only adds a gradient path through n
            n = n + self.drop(qo - n.detach() + n)
            qs = self.qrouter.sparsity()
        x = x + self.drop(self.ffn(n))
        return {'out': x, 'aw': aw, 'entropy': ent, 'rank': tr, 'qsparse': qs}

# ═════════════════════════════════════════════════════════════════════
# 8. Q-TENSORFORMER MODEL
# ═════════════════════════════════════════════════════════════════════

class QTensorFormer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.tok = nn.Embedding(cfg.vocab, cfg.d_model)
        self.pos = nn.Parameter(torch.randn(1, cfg.max_seq, cfg.d_model)*0.02)
        self.layers = nn.ModuleList([HybridBlock(cfg) for _ in range(cfg.n_layers)])
        self.norm = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab, bias=False)
        self.head.weight = self.tok.weight
        self._init()
    def _init(self):
        # blanket Xavier init; note this overwrites the scaled positional and
        # TT-core initialisations done inside the sub-modules
        for p in self.parameters():
            if p.dim() >= 2: nn.init.xavier_uniform_(p)
    def forward(self, ids, mask=None, adapt=True):
        B, S = ids.shape
        x = self.tok(ids) + self.pos[:, :S, :]
        if mask is not None: mask = mask[:, None, None, :]
        bos = []
        for l in self.layers:
            o = l(x, mask, adapt); x = o['out']; bos.append(o)
        x = self.norm(x); logits = self.head(x)
        ent = torch.stack([b['entropy'] for b in bos]).mean()
        rk = sum(b['rank'] for b in bos)/len(bos)
        qs = sum(b['qsparse'] for b in bos)/len(bos)
        return {'logits': logits, 'entropy': ent, 'rank': rk, 'qsparse': qs}
    def loss(self, ids, mask=None, labels=None):
        if labels is None: labels = ids.clone()
        out = self(ids, mask)
        sl = out['logits'][:, :-1].contiguous()
        ll = labels[:, 1:].contiguous()
        l = F.cross_entropy(sl.reshape(-1, self.cfg.vocab), ll.reshape(-1), ignore_index=-100)
        return {'loss': l, 'ppl': torch.exp(l), 'entropy': out['entropy'], 'rank': out['rank'], 'qsparse': out['qsparse']}
    def nparams(self):
        t = sum(p.numel() for p in self.parameters())
        tr = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {'total': t, 'trainable': tr}
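
# Usage sketch (illustrative, not executed):
#   model = QTensorFormer(CFG())
#   ids = torch.randint(1, 1000, (2, 64))
#   out = model.loss(ids)   # dict with 'loss', 'ppl', 'entropy', 'rank', 'qsparse'
#   out['loss'].backward()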

# ═════════════════════════════════════════════════════════════════════
# 9. BASELINE (identical architecture, dense FFN)
# ═════════════════════════════════════════════════════════════════════

class Baseline(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.tok = nn.Embedding(cfg.vocab, cfg.d_model)
        self.pos = nn.Parameter(torch.randn(1, cfg.max_seq, cfg.d_model)*0.02)
        self.drop = nn.Dropout(cfg.dropout)
        self.layers = nn.ModuleList()
        for _ in range(cfg.n_layers):
            self.layers.append(nn.ModuleDict({
                'a_n': nn.LayerNorm(cfg.d_model),
                'a': MHA(cfg.d_model, cfg.n_heads, cfg.dropout),
                'f_n': nn.LayerNorm(cfg.d_model),
                'ff': nn.Sequential(
                    nn.Linear(cfg.d_model, cfg.d_model*cfg.ff_mult),
                    nn.GELU(), nn.Dropout(cfg.dropout),
                    nn.Linear(cfg.d_model*cfg.ff_mult, cfg.d_model)),
            }))
        self.norm = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab, bias=False)
        self.head.weight = self.tok.weight
        self._init()
    def _init(self):
        for p in self.parameters():
            if p.dim() >= 2: nn.init.xavier_uniform_(p)
    def forward(self, ids, mask=None):
        B, S = ids.shape
        x = self.tok(ids) + self.pos[:, :S, :]; x = self.drop(x)
        m = mask[:, None, None, :] if mask is not None else None
        for l in self.layers:
            ao, _ = l['a'](l['a_n'](x), m); x = x + self.drop(ao)
            x = x + self.drop(l['ff'](l['f_n'](x)))
        return {'logits': self.head(self.norm(x))}
    def loss(self, ids, mask=None, labels=None):
        if labels is None: labels = ids.clone()
        out = self(ids, mask)
        sl = out['logits'][:, :-1].contiguous()
        ll = labels[:, 1:].contiguous()
        l = F.cross_entropy(sl.reshape(-1, self.cfg.vocab), ll.reshape(-1), ignore_index=-100)
        return {'loss': l, 'ppl': torch.exp(l)}
    def nparams(self):
        t = sum(p.numel() for p in self.parameters())
        tr = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {'total': t, 'trainable': tr}

# ═════════════════════════════════════════════════════════════════════
# 10. TRAINING UTILITIES
# ═════════════════════════════════════════════════════════════════════

def make_data(vocab=1000, seq=64, n=500, bs=16):
    d = torch.randint(1, vocab, (n, seq))
    ds = torch.utils.data.TensorDataset(d)
    return torch.utils.data.DataLoader(ds, batch_size=bs, shuffle=True,
        collate_fn=lambda batch: {'input_ids': torch.stack([item[0] for item in batch])})
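
# Data note (observation): the tokens are i.i.d. uniform over 1..vocab-1, so the
# best achievable next-token perplexity on this synthetic set is about vocab-1
# (~999); the benchmark therefore compares relative convergence rather than
# language-modelling quality.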

def train_epoch(model, dl, opt, sched, e, tag="M"):
    model.train(); tl, tp, nb = 0.0, 0.0, 0; ex = {}
    for b in dl:
        ids = b['input_ids']; m = b.get('attention_mask')
        opt.zero_grad()
        out = model.loss(ids, m); out['loss'].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        if sched: sched.step()
        tl += out['loss'].item(); tp += out['ppl'].item(); nb += 1
        for k in ['entropy', 'rank', 'qsparse']:
            if k in out: ex[k] = ex.get(k, 0.0) + (out[k].item() if isinstance(out[k], torch.Tensor) else out[k])
    al, ap = tl/max(nb, 1), tp/max(nb, 1)
    s = f"[{tag}] E{e:2d} loss={al:.4f} ppl={ap:6.1f}"
    for k, v in ex.items(): s += f" {k}={v/max(nb,1):.3f}"
    print(s); return al, ap

@torch.no_grad()
def evaluate(model, dl):
    model.eval(); tl, tp, nb = 0.0, 0.0, 0
    for b in dl:
        ids = b['input_ids']; m = b.get('attention_mask')
        out = model.loss(ids, m); tl += out['loss'].item(); tp += out['ppl'].item(); nb += 1
    return tl/max(nb, 1), tp/max(nb, 1)

# ═════════════════════════════════════════════════════════════════════
# 11. MAIN BENCHMARK
# ═════════════════════════════════════════════════════════════════════

def main():
    torch.manual_seed(42)
    cfg = CFG(d_model=64, n_layers=2, n_heads=4, tt_rank=8,
              q_qubits=4, q_sparsity=0.3, vocab=1000, max_seq=64)

    print(f"Config: d={cfg.d_model} layers={cfg.n_layers} heads={cfg.n_heads} rank={cfg.tt_rank}")
    print(f"Quantum: qubits={cfg.q_qubits} sparsity={cfg.q_sparsity}")
    print("Tensor FFN: ON\n")

    qt = QTensorFormer(cfg)
    bl = Baseline(cfg)

    pq = qt.nparams(); pb = bl.nparams()
    print(f"Q-TensorFormer params: {pq['trainable']:>10,}")
    print(f"Baseline params:       {pb['trainable']:>10,}")
    print(f"Compression ratio:     {pb['trainable']/max(pq['trainable'],1):>10.1f}x\n")

    train_dl = make_data(cfg.vocab, cfg.max_seq, 500, 16)
    val_dl = make_data(cfg.vocab, cfg.max_seq, 100, 16)
    E = 8

    print("=" * 50)
    print(" TRAINING Q-TENSORFORMER")
    print("=" * 50)
    oq = torch.optim.AdamW(qt.parameters(), lr=cfg.lr)
    sq = torch.optim.lr_scheduler.CosineAnnealingLR(oq, E*len(train_dl))
    for e in range(1, E+1): train_epoch(qt, train_dl, oq, sq, e, "Q-TF")

    print("\n" + "=" * 50)
    print(" TRAINING BASELINE")
    print("=" * 50)
    ob = torch.optim.AdamW(bl.parameters(), lr=cfg.lr)
    sb = torch.optim.lr_scheduler.CosineAnnealingLR(ob, E*len(train_dl))
    for e in range(1, E+1): train_epoch(bl, train_dl, ob, sb, e, "BSL")

    ql, qp = evaluate(qt, val_dl)
    bl_val, bp = evaluate(bl, val_dl)

    torch.save(qt.state_dict(), '/tmp/qt.pt')
    torch.save(bl.state_dict(), '/tmp/bl.pt')
    qsz = os.path.getsize('/tmp/qt.pt')/(1024*1024)
    bsz = os.path.getsize('/tmp/bl.pt')/(1024*1024)

    print("\n" + "=" * 65)
    print(" RESULTS")
    print("=" * 65)
    print(f"{'Metric':<30} {'Q-TensorFormer':>15} {'Baseline':>15}")
    print("-" * 60)
    print(f"{'Parameters':<30} {pq['trainable']:>15,} {pb['trainable']:>15,}")
    print(f"{'Val Loss':<30} {ql:>15.4f} {bl_val:>15.4f}")
    print(f"{'Val Perplexity':<30} {qp:>15.2f} {bp:>15.2f}")
    print(f"{'Model Size (MB)':<30} {qsz:>15.1f} {bsz:>15.1f}")

    ps = (1 - pq['trainable']/pb['trainable'])*100
    ss = (1 - qsz/bsz)*100
    pr = qp/bp
    print(f"\nParameter reduction: {ps:.1f}%")
    print(f"Size reduction:      {ss:.1f}%")
    print(f"PPL ratio (Q-TF/BL): {pr:.2f}x")

    if pr < 1.1:
        print("\n >> VERDICT: Significant compression with minimal quality loss! <<")
    elif pr < 1.3:
        print("\n >> VERDICT: Moderate trade-off — compression worth the cost <<")
    else:
        print("\n >> VERDICT: Quality gap too large, needs tuning <<")

    print("\nDone!")
    return {'params_q': pq['trainable'], 'params_b': pb['trainable'], 'qloss': ql, 'qppl': qp,
            'bloss': bl_val, 'bppl': bp, 'qsz': qsz, 'bsz': bsz, 'comp': ps, 'sred': ss, 'ppl_ratio': pr}

if __name__ == '__main__':
    results = main()