"""
Q-TensorFormer: Quantum-Enhanced Tensor Network LLM Compression Engine
=======================================================================
Hybrid quantum-tensor transformer with:
- Pure PyTorch Tensor-Train FFN layers (no compiled deps)
- PennyLane quantum angle encoding with TorchLayer
- Attention-entropy-guided adaptive rank scheduling
- Selective quantum routing (only "hard" tokens)
- Full benchmark against identical-architecture baseline
"""
|
|
import math, os

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

import pennylane as qml

print("=" * 65)
print(" Q-TENSORFORMER: Quantum-Tensor LLM Compressor")
print("=" * 65)
print(f" PyTorch {torch.__version__} | PennyLane {qml.__version__}")
print()
|
|
|
|
@dataclass
class CFG:
    """Model and training hyperparameters."""
    d_model: int = 64
    n_heads: int = 4
    n_layers: int = 2
    ff_mult: int = 4
    max_seq: int = 64
    vocab: int = 1000
    tt_rank: int = 8             # maximum TT bond dimension
    min_rank: int = 2            # floor for the adaptive rank scheduler
    q_qubits: int = 4
    q_layers: int = 2
    q_sparsity: float = 0.3      # nominal quantum-routing sparsity (reported in the banner)
    dropout: float = 0.1
    lr: float = 3e-4
    rank_alpha: float = 2.0      # entropy -> rank scaling factor
    rank_smoothing: float = 0.9  # EMA smoothing of the entropy signal
|
|
|
|
def auto_factor(n, max_f=4):
    """Greedily factor n into at most max_f small factors, used as TT mode sizes."""
    if n <= 1: return (1, 1)
    f, r = [], n
    for p in [2,2,2,2,2,3,3,5,7]:
        while r % p == 0 and len(f) < max_f:
            f.append(p); r //= p
    if r > 1:
        if len(f) < max_f: f.append(r)
        else: f[-1] *= r
    while len(f) < 2: f.insert(0, 1)
    return tuple(f[:max_f])
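# Example factorizations for the dimensions used below (d_model=64, d_model*ff_mult=256):
#   auto_factor(64)  -> (2, 2, 2, 8)
#   auto_factor(256) -> (2, 2, 2, 32)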
|
|
class TTLinear(nn.Module):
    """Tensor-Train decomposed linear layer. Pure PyTorch, zero compiled deps."""
    def __init__(self, in_shape, out_shape, rank=8, bias=True):
        super().__init__()
        in_shape = tuple(in_shape)
        out_shape = tuple(out_shape)
        # Pad the shorter factorization with 1s so both have the same number of modes.
        max_d = max(len(in_shape), len(out_shape))
        in_shape = (1,) * (max_d - len(in_shape)) + in_shape
        out_shape = (1,) * (max_d - len(out_shape)) + out_shape
        assert len(in_shape) == len(out_shape)
        self.in_shape, self.out_shape = in_shape, out_shape
        self.rank, self.ndim = rank, len(in_shape)
        self.active_rank = rank  # effective bond dimension used in forward (<= rank)
        self.in_feat = math.prod(in_shape)
        self.out_feat = math.prod(out_shape)
        self.cores = nn.ParameterList()
        for k in range(self.ndim):
            rl = 1 if k == 0 else rank
            rr = 1 if k == self.ndim - 1 else rank
            c = torch.empty(rl, out_shape[k], in_shape[k], rr)
            bnd = math.sqrt(6.0 / max(1, rl*in_shape[k] + rr*out_shape[k]))
            nn.init.uniform_(c, -bnd, bnd)
            self.cores.append(nn.Parameter(c))  # wrap so the core is registered as trainable
        self.bias = nn.Parameter(torch.zeros(self.out_feat)) if bias else None
        tp = sum(c.numel() for c in self.cores) + (self.bias.numel() if bias else 0)
        self.compr = (self.in_feat * self.out_feat) / max(tp, 1)
|
|
    def forward(self, x):
        bs = x.shape[:-1]
        B = math.prod(bs)
        x = x.reshape(B, self.in_feat)
        state = x.reshape(B, *self.in_shape)

        for k in range(self.ndim):
            core = self.cores[k]
            # Slice each bond down to the active rank; gradients still reach the full cores,
            # so the scheduler can shrink or grow the rank without touching the optimizer.
            rl = 1 if k == 0 else min(self.active_rank, core.shape[0])
            rr = 1 if k == self.ndim - 1 else min(self.active_rank, core.shape[3])
            core = core[:rl, :, :, :rr]
            r_k, o_k, i_k, r_kp1 = core.shape

            if k == 0:
                # First core: contract the first input mode, producing bond r_1.
                rest = math.prod(self.in_shape[1:])
                s = state.reshape(B, i_k, rest)
                cm = core.squeeze(0).permute(1, 0, 2).reshape(i_k, o_k * r_kp1)
                s = torch.bmm(s.transpose(1, 2), cm.unsqueeze(0).expand(B, -1, -1))
                s = s.reshape(B, rest, o_k, r_kp1).permute(0, 3, 2, 1)
                state = s.reshape(B, r_kp1, -1)

            elif k == self.ndim - 1:
                # Last core: absorb the final input mode and close the bond.
                prev_os = math.prod(self.out_shape[:k])
                s = state.reshape(B, r_k, prev_os, i_k)
                cm = core.squeeze(-1)
                s = torch.einsum('brpi,roi->bpo', s, cm)
                state = s.reshape(B, prev_os * o_k)

            else:
                # Middle cores: contract one input mode, emit one output mode.
                prev_os = math.prod(self.out_shape[:k])
                rest_in = math.prod(self.in_shape[k+1:])
                s = state.reshape(B, r_k, prev_os, i_k, rest_in)
                s = torch.einsum('brpix,roiq->bpoqx', s, core)
                s = s.permute(0, 3, 1, 2, 4)
                state = s.reshape(B, r_kp1, prev_os * o_k * rest_in)

        out = state.reshape(B, self.out_feat)
        if self.bias is not None: out = out + self.bias
        return out.reshape(*bs, self.out_feat)
|
|
    def set_rank(self, nr):
        # Non-destructive: slicing happens in forward(), so the full cores stay registered
        # with the optimizer and the rank can be raised again later.
        self.active_rank = int(max(1, min(nr, self.rank)))
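# Rough size check (illustrative arithmetic, not executed): with d_model=64 and ff_mult=4,
# the "up" projection maps 64 -> 256 with factorizations (2,2,2,8) -> (2,2,2,32). At rank 8
# the four cores hold 1*2*2*8 + 8*2*2*8 + 8*2*2*8 + 8*32*8*1 = 2,592 weights plus a 256-dim
# bias, versus 64*256 = 16,384 weights for the equivalent dense layer.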
|
|
|
|
class QuantumEmbed(nn.Module):
    """Angle embedding → variational circuit → PauliZ expectations."""
    def __init__(self, n_q=4, layers=2, n_out=None):
        super().__init__()
        self.n_q, self.layers = n_q, layers
        n_out = n_out or n_q
        dev = qml.device("default.qubit", wires=n_q)

        @qml.qnode(dev, interface="torch", diff_method="backprop")
        def circ(inputs, w):
            # Angle-encode one feature per qubit, then apply `layers` blocks of trainable
            # RY rotations followed by a ring of CNOT entanglers.
            for i in range(n_q): qml.RX(inputs[..., i], wires=i)
            for L in range(layers):
                for i in range(n_q): qml.RY(w[L, i], wires=i)
                for i in range(n_q-1): qml.CNOT(wires=[i, i+1])
                if n_q > 2: qml.CNOT(wires=[n_q-1, 0])
            return [qml.expval(qml.PauliZ(i)) for i in range(n_out)]

        self.qlayer = qml.qnn.TorchLayer(circ, {"w": (layers, n_q)})

    def forward(self, x): return self.qlayer(x)
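# Each routed token is first projected down to n_q features (see HybridBlock), angle-encoded
# as RX rotations, and read out as n_q PauliZ expectation values in [-1, 1]. The circuit adds
# only layers * n_q trainable weights (2 * 4 = 8 with the default CFG).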
|
|
|
|
class TTFFN(nn.Module):
    """Position-wise feed-forward block with both projections as TT-decomposed linears."""
    def __init__(self, D, ff_mult=4, rank=8):
        super().__init__()
        E = D * ff_mult
        self.up = TTLinear(auto_factor(D), auto_factor(E), rank, True)
        self.down = TTLinear(auto_factor(E), auto_factor(D), rank, True)
    def forward(self, x): return self.down(F.gelu(self.up(x)))
    def set_rank(self, r): self.up.set_rank(r); self.down.set_rank(r)
|
|
|
|
class RankScheduler(nn.Module):
    """rank = r_min + alpha * entropy (EMA-smoothed), clamped to [r_min, r_max]."""
    def __init__(self, mn=2, mx=16, a=2.0, sm=0.9):
        super().__init__()
        self.mn, self.mx = mn, mx
        self.alpha = nn.Parameter(torch.tensor(a))  # no gradient path: the rank is rounded to an int
        self.sm = sm
        self.register_buffer('ema', torch.tensor(0.5))
        self.register_buffer('cur', torch.tensor(float(mx)))
    def forward(self, ent):
        s = ent.mean().detach() if ent.numel() > 1 else ent.detach()
        self.ema = self.sm*self.ema + (1-self.sm)*s
        raw = self.mn + self.alpha*self.ema
        r = int(torch.clamp(raw, self.mn, self.mx).round().item())
        if self.training: self.cur.fill_(r)
        return r
    @property
    def current(self): return int(self.cur.item())
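# Worked example with the default CFG (min_rank=2, rank_alpha=2.0): attention entropy over a
# 64-token sequence lies in [0, ln 64 ≈ 4.16] nats, so a smoothed entropy of ~1.5 gives
# rank ≈ 2 + 2.0 * 1.5 = 5, while sharply peaked attention drives the rank toward the floor of 2.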
|
|
|
|
class QuantumRouter(nn.Module):
    """Learned gate: routes only "hard" tokens through the quantum circuit."""
    def __init__(self, D, qmod, thr=0.5, q_out=None):
        super().__init__()
        self.qmod = qmod
        self.thr = thr
        self.gate = nn.Sequential(
            nn.Linear(D, D//4), nn.ReLU(), nn.Linear(D//4, 1), nn.Sigmoid())
        # Build the back-projection up front: a layer created lazily inside forward()
        # would be invisible to an optimizer constructed before the first batch.
        self.proj = nn.Linear(q_out, D) if (q_out is not None and q_out != D) else None
        self.register_buffer('tot', torch.tensor(0.0))
        self.register_buffer('qtok', torch.tensor(0.0))
    def forward(self, x):
        B, S, D = x.shape
        g = self.gate(x.reshape(-1, D)).squeeze(-1).reshape(B, S)
        m = (g > self.thr).float()
        if self.training:
            # Straight-through estimator: hard 0/1 routing in the forward pass,
            # gradients reach the gate through the sigmoid scores g.
            m = m.detach() + g - g.detach()
        xf = x.reshape(-1, D); mf = m.reshape(-1)
        hard = mf > 0.5
        sel = xf[hard]; qvals = xf.clone()
        if sel.shape[0] > 0:
            qo = self.qmod(sel)
            if self.proj is not None and qo.shape[-1] != D:
                qo = self.proj(qo)
            qvals[hard] = qo.to(xf.dtype)
        # Mix with the mask (rather than hard index assignment) so the straight-through
        # gate scores actually receive gradients for the routed tokens.
        out = xf + mf.unsqueeze(-1) * (qvals - xf)
        self.tot += B*S; self.qtok += m.detach().sum()  # bookkeeping only
        return out.reshape(B, S, D), g
    def sparsity(self):
        if self.tot > 0: return 1.0 - (self.qtok/self.tot).item()
        return 1.0
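# Routing bookkeeping: `tot` counts every token seen and `qtok` counts routed ones, so
# sparsity() is the fraction of tokens that bypassed the quantum circuit (1.0 = never used).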
|
|
|
|
class MHA(nn.Module):
    """Standard multi-head self-attention. `mask` (if given) is a (B, S) padding mask."""
    def __init__(self, D, heads=4, drop=0.1):
        super().__init__()
        assert D % heads == 0
        self.h, self.hd = heads, D//heads
        self.scale = self.hd**-0.5
        self.qkv = nn.Linear(D, 3*D, bias=False)
        self.out = nn.Linear(D, D)
        self.drop = nn.Dropout(drop)
    def forward(self, x, mask=None):
        B, S, D = x.shape
        qkv = self.qkv(x).reshape(B, S, 3, self.h, self.hd).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        a = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            a = a.masked_fill(mask[:, None, None, :] == 0, float('-inf'))
        aw = F.softmax(a, dim=-1); aw = self.drop(aw)
        o = (aw @ v).transpose(1, 2).reshape(B, S, D)
        return self.out(o), aw
|
|
|
|
class HybridBlock(nn.Module):
    """Transformer block: MHA + entropy-scheduled TT-FFN + optional quantum token routing."""
    def __init__(self, cfg):
        super().__init__()
        D = cfg.d_model
        self.a_norm = nn.LayerNorm(D)
        self.attn = MHA(D, cfg.n_heads, cfg.dropout)
        self.f_norm = nn.LayerNorm(D)
        self.ffn = TTFFN(D, cfg.ff_mult, cfg.tt_rank)
        self.qrouter = None
        if cfg.q_qubits:
            qc = QuantumEmbed(cfg.q_qubits, cfg.q_layers, cfg.q_qubits)
            qw = nn.Sequential(nn.Linear(D, cfg.q_qubits), qc)
            self.qrouter = QuantumRouter(D, qw, q_out=cfg.q_qubits)
        self.rs = RankScheduler(cfg.min_rank, cfg.tt_rank, cfg.rank_alpha, cfg.rank_smoothing)
        self.drop = nn.Dropout(cfg.dropout)
    def forward(self, x, mask=None, adapt=True):
        ao, aw = self.attn(self.a_norm(x), mask)
        x = x + self.drop(ao)
        # Mean attention entropy drives the TT rank: diffuse attention -> higher rank.
        eps = 1e-8
        ent = -torch.sum(aw*torch.log(aw+eps), dim=-1).mean(dim=-1).mean()
        tr = self.rs(ent) if adapt else self.rs.mx
        if adapt: self.ffn.set_rank(tr)
        n = self.f_norm(x)
        qs = 1.0
        if self.qrouter is not None:
            qo, _ = self.qrouter(n)
            # Residual quantum update with a straight-through term: the forward value is
            # n + dropout(qo); the (n - n.detach()) term adds a direct gradient path into n.
            n = n + self.drop(qo - n.detach() + n)
            qs = self.qrouter.sparsity()
        x = x + self.drop(self.ffn(n))
        return {'out': x, 'aw': aw, 'entropy': ent, 'rank': tr, 'qsparse': qs}
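# Per-block flow: pre-norm attention -> entropy readout -> TT rank update -> optional quantum
# routing on the normed stream -> TT-FFN residual. The returned dict exposes the attention map,
# entropy, chosen rank and quantum sparsity so the training loop can log them per batch.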
|
|
|
|
class QTensorFormer(nn.Module):
    """Token-level LM with TT-compressed FFNs and selective quantum routing."""
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.tok = nn.Embedding(cfg.vocab, cfg.d_model)
        self.pos = nn.Parameter(torch.randn(1, cfg.max_seq, cfg.d_model)*0.02)
        self.layers = nn.ModuleList([HybridBlock(cfg) for _ in range(cfg.n_layers)])
        self.norm = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab, bias=False)
        self.head.weight = self.tok.weight  # weight tying
        self._init()
    def _init(self):
        # Only re-init plain 2-D weight matrices; TT cores and the positional table keep
        # their own initialization.
        for p in self.parameters():
            if p.dim() == 2: nn.init.xavier_uniform_(p)
    def forward(self, ids, mask=None, adapt=True):
        B, S = ids.shape
        x = self.tok(ids) + self.pos[:, :S, :]
        bos = []
        for l in self.layers:  # each block gets the raw (B, S) padding mask; MHA expands it
            o = l(x, mask, adapt); x = o['out']; bos.append(o)
        x = self.norm(x); logits = self.head(x)
        ent = torch.stack([b['entropy'] for b in bos]).mean()
        rk = sum(b['rank'] for b in bos)/len(bos)
        qs = sum(b['qsparse'] for b in bos)/len(bos)
        return {'logits': logits, 'entropy': ent, 'rank': rk, 'qsparse': qs}
    def loss(self, ids, mask=None, labels=None):
        if labels is None: labels = ids.clone()
        out = self(ids, mask)
        # Next-token objective: shift logits and labels by one position.
        sl = out['logits'][:, :-1].contiguous()
        ll = labels[:, 1:].contiguous()
        l = F.cross_entropy(sl.reshape(-1, self.cfg.vocab), ll.reshape(-1), ignore_index=-100)
        return {'loss': l, 'ppl': torch.exp(l), 'entropy': out['entropy'], 'rank': out['rank'], 'qsparse': out['qsparse']}
    def nparams(self):
        t = sum(p.numel() for p in self.parameters())
        tr = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {'total': t, 'trainable': tr}
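# Illustrative usage (not executed at import time; shapes assume the default CFG):
#   cfg = CFG()
#   model = QTensorFormer(cfg)
#   ids = torch.randint(1, cfg.vocab, (2, cfg.max_seq))
#   out = model(ids)            # out['logits'].shape == (2, 64, 1000)
#   metrics = model.loss(ids)   # dict with loss, ppl, entropy, rank, qsparse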
|
|
|
|
class Baseline(nn.Module):
    """Identical architecture with dense FFNs and no quantum routing, for comparison."""
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.tok = nn.Embedding(cfg.vocab, cfg.d_model)
        self.pos = nn.Parameter(torch.randn(1, cfg.max_seq, cfg.d_model)*0.02)
        self.drop = nn.Dropout(cfg.dropout)
        self.layers = nn.ModuleList()
        for _ in range(cfg.n_layers):
            self.layers.append(nn.ModuleDict({
                'a_n': nn.LayerNorm(cfg.d_model),
                'a': MHA(cfg.d_model, cfg.n_heads, cfg.dropout),
                'f_n': nn.LayerNorm(cfg.d_model),
                'ff': nn.Sequential(
                    nn.Linear(cfg.d_model, cfg.d_model*cfg.ff_mult),
                    nn.GELU(), nn.Dropout(cfg.dropout),
                    nn.Linear(cfg.d_model*cfg.ff_mult, cfg.d_model)),
            }))
        self.norm = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab, bias=False)
        self.head.weight = self.tok.weight  # weight tying
        self._init()
    def _init(self):
        for p in self.parameters():
            if p.dim() == 2: nn.init.xavier_uniform_(p)
    def forward(self, ids, mask=None):
        B, S = ids.shape
        x = self.tok(ids) + self.pos[:, :S, :]; x = self.drop(x)
        for l in self.layers:
            # MHA expands the raw (B, S) padding mask itself.
            ao, _ = l['a'](l['a_n'](x), mask); x = x + self.drop(ao)
            x = x + self.drop(l['ff'](l['f_n'](x)))
        return {'logits': self.head(self.norm(x))}
    def loss(self, ids, mask=None, labels=None):
        if labels is None: labels = ids.clone()
        out = self(ids, mask)
        sl = out['logits'][:, :-1].contiguous()
        ll = labels[:, 1:].contiguous()
        l = F.cross_entropy(sl.reshape(-1, self.cfg.vocab), ll.reshape(-1), ignore_index=-100)
        return {'loss': l, 'ppl': torch.exp(l)}
    def nparams(self):
        t = sum(p.numel() for p in self.parameters())
        tr = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {'total': t, 'trainable': tr}
|
|
|
|
def make_data(vocab=1000, seq=64, n=500, bs=16):
    """Synthetic dataset of uniformly random token ids (no attention_mask needed)."""
    d = torch.randint(1, vocab, (n, seq))
    ds = torch.utils.data.TensorDataset(d)
    return torch.utils.data.DataLoader(ds, batch_size=bs, shuffle=True,
        collate_fn=lambda batch: {'input_ids': torch.stack([item[0] for item in batch])})
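# Note: the training corpus is i.i.d. random tokens, so the benchmark below compares the
# two (identically structured) models on parameter count, checkpoint size and relative fit
# rather than on real language-modeling quality.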
|
|
def train_epoch(model, dl, opt, sched, e, tag="M"):
    """One training epoch; returns (mean loss, mean perplexity) and prints extra metrics."""
    model.train(); tl, tp, nb = 0.0, 0.0, 0; ex = {}
    for b in dl:
        ids = b['input_ids']; m = b.get('attention_mask')
        opt.zero_grad()
        out = model.loss(ids, m); out['loss'].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        if sched: sched.step()
        tl += out['loss'].item(); tp += out['ppl'].item(); nb += 1
        for k in ['entropy', 'rank', 'qsparse']:
            if k in out: ex[k] = ex.get(k, 0.0) + (out[k].item() if isinstance(out[k], torch.Tensor) else out[k])
    al, ap = tl/max(nb, 1), tp/max(nb, 1)
    s = f"[{tag}] E{e:2d} loss={al:.4f} ppl={ap:6.1f}"
    for k, v in ex.items(): s += f" {k}={v/max(nb,1):.3f}"
    print(s); return al, ap
|
|
@torch.no_grad()
def evaluate(model, dl):
    """Mean validation loss and perplexity over a dataloader."""
    model.eval(); tl, tp, nb = 0.0, 0.0, 0
    for b in dl:
        ids = b['input_ids']; m = b.get('attention_mask')
        out = model.loss(ids, m); tl += out['loss'].item(); tp += out['ppl'].item(); nb += 1
    return tl/max(nb, 1), tp/max(nb, 1)
|
|
|
|
def main():
    torch.manual_seed(42)
    cfg = CFG(d_model=64, n_layers=2, n_heads=4, tt_rank=8,
              q_qubits=4, q_sparsity=0.3, vocab=1000, max_seq=64)

    print(f"Config: d={cfg.d_model} layers={cfg.n_layers} heads={cfg.n_heads} rank={cfg.tt_rank}")
    print(f"Quantum: qubits={cfg.q_qubits} sparsity={cfg.q_sparsity}")
    print("Tensor FFN: ON\n")

    qt = QTensorFormer(cfg)
    bl = Baseline(cfg)

    pq = qt.nparams(); pb = bl.nparams()
    print(f"Q-TensorFormer params: {pq['trainable']:>10,}")
    print(f"Baseline params:       {pb['trainable']:>10,}")
    print(f"Compression ratio:     {pb['trainable']/max(pq['trainable'],1):>10.1f}x\n")

    train_dl = make_data(cfg.vocab, cfg.max_seq, 500, 16)
    val_dl = make_data(cfg.vocab, cfg.max_seq, 100, 16)
    E = 8  # epochs per model

    print("=" * 50)
    print(" TRAINING Q-TENSORFORMER")
    print("=" * 50)
    oq = torch.optim.AdamW(qt.parameters(), lr=cfg.lr)
    sq = torch.optim.lr_scheduler.CosineAnnealingLR(oq, E*len(train_dl))
    for e in range(1, E+1): train_epoch(qt, train_dl, oq, sq, e, "Q-TF")

    print("\n" + "=" * 50)
    print(" TRAINING BASELINE")
    print("=" * 50)
    ob = torch.optim.AdamW(bl.parameters(), lr=cfg.lr)
    sb = torch.optim.lr_scheduler.CosineAnnealingLR(ob, E*len(train_dl))
    for e in range(1, E+1): train_epoch(bl, train_dl, ob, sb, e, "BSL")

    ql, qp = evaluate(qt, val_dl)
    bl_val, bp = evaluate(bl, val_dl)

    # Compare serialized checkpoint sizes on disk.
    torch.save(qt.state_dict(), '/tmp/qt.pt')
    torch.save(bl.state_dict(), '/tmp/bl.pt')
    qsz = os.path.getsize('/tmp/qt.pt')/(1024*1024)
    bsz = os.path.getsize('/tmp/bl.pt')/(1024*1024)

    print("\n" + "=" * 65)
    print(" RESULTS")
    print("=" * 65)
    print(f"{'Metric':<30} {'Q-TensorFormer':>15} {'Baseline':>15}")
    print("-" * 60)
    print(f"{'Parameters':<30} {pq['trainable']:>15,} {pb['trainable']:>15,}")
    print(f"{'Val Loss':<30} {ql:>15.4f} {bl_val:>15.4f}")
    print(f"{'Val Perplexity':<30} {qp:>15.2f} {bp:>15.2f}")
    print(f"{'Model Size (MB)':<30} {qsz:>15.1f} {bsz:>15.1f}")

    ps = (1 - pq['trainable']/pb['trainable'])*100
    ss = (1 - qsz/bsz)*100
    pr = qp/bp
    print(f"\nParameter reduction: {ps:.1f}%")
    print(f"Size reduction:      {ss:.1f}%")
    print(f"PPL ratio (Q-TF/BL): {pr:.2f}x")

    if pr < 1.1:
        print("\n >> VERDICT: Significant compression with minimal quality loss! <<")
    elif pr < 1.3:
        print("\n >> VERDICT: Moderate trade-off: compression worth the cost <<")
    else:
        print("\n >> VERDICT: Quality gap too large, needs tuning <<")

    print("\nDone!")
    return {'params_q': pq['trainable'], 'params_b': pb['trainable'], 'qloss': ql, 'qppl': qp,
            'bloss': bl_val, 'bppl': bp, 'qsz': qsz, 'bsz': bsz,
            'comp': ps, 'sred': ss, 'ppl_ratio': pr}
|
|
if __name__ == '__main__':
    results = main()
|
|