#!/usr/bin/env python3
"""
Q-TensorFormer: Quantum-Enhanced Tensor Network LLM Compression Engine
=======================================================================
Hybrid quantum-tensor transformer with:
  - Pure PyTorch Tensor-Train FFN layers (no compiled deps)
  - PennyLane quantum angle encoding with TorchLayer
  - Attention-entropy-guided (EMA-smoothed) adaptive TT rank scheduling
  - Selective quantum routing (only "hard" tokens)
  - Full benchmark against a dense-FFN baseline sharing the same backbone
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
from typing import Optional, Tuple
from dataclasses import dataclass
import pennylane as qml

print("=" * 65)
print(" Q-TENSORFORMER: Quantum-Tensor LLM Compressor")
print("=" * 65)
print(f" PyTorch {torch.__version__} | PennyLane {qml.__version__}")
print()


# ═════════════════════════════════════════════════════════════════════
# CONFIG
# ═════════════════════════════════════════════════════════════════════
@dataclass
class CFG:
    """Hyper-parameters shared by QTensorFormer and Baseline."""
    d_model: int = 64          # residual / embedding width
    n_heads: int = 4           # attention heads (d_model must divide evenly)
    n_layers: int = 2          # number of transformer blocks
    ff_mult: int = 4           # FFN hidden width = d_model * ff_mult
    max_seq: int = 64          # length of the learned positional table
    vocab: int = 1000          # vocabulary size (embedding rows, logits)
    tt_rank: int = 8           # TT bond rank; also the rank scheduler's max
    min_rank: int = 2          # rank scheduler's lower bound
    q_qubits: int = 4          # qubits per quantum embed; 0 disables routing
    q_layers: int = 2          # variational layers in the quantum circuit
    q_sparsity: float = 0.3    # NOTE(review): only printed by main(); not read by the router
    dropout: float = 0.1
    lr: float = 3e-4
    rank_alpha: float = 2.0    # initial alpha in rank = min_rank + alpha * EMA(entropy)
    rank_smoothing: float = 0.9  # EMA decay for the entropy signal


# ═════════════════════════════════════════════════════════════════════
# 1. PURE PYTORCH TENSOR-TRAIN LINEAR LAYER
# ═════════════════════════════════════════════════════════════════════
def auto_factor(n, max_f=4):
    """Split ``n`` into a tuple of small factors whose product is ``n``.

    Prime factors 2, 3, 5, 7 are peeled off smallest-first until ``max_f``
    factors are collected; any remaining cofactor is appended (if room) or
    folded into the last factor. The result is left-padded with 1s to at
    least two entries and never has more than ``max(max_f, 2)`` entries.
    ``n <= 1`` yields ``(1, 1)``.
    """
    if n <= 1:
        return (1, 1)
    # Fewer than two slots cannot hold a product-preserving pair; the old
    # post-hoc truncation to max_f silently dropped factors in that case.
    max_f = max(int(max_f), 2)
    factors, rest = [], n
    for p in (2, 3, 5, 7):
        while rest % p == 0 and len(factors) < max_f:
            factors.append(p)
            rest //= p
    if rest > 1:
        if len(factors) < max_f:
            factors.append(rest)
        else:
            factors[-1] *= rest
    while len(factors) < 2:
        factors.insert(0, 1)
    return tuple(factors)


class TTLinear(nn.Module):
    """Tensor-Train decomposed linear layer. Pure PyTorch, zero compiled deps.

    The dense weight of shape ``(prod(out_shape), prod(in_shape))`` is stored
    as ``ndim`` cores; core ``k`` has shape
    ``(r_k, out_shape[k], in_shape[k], r_{k+1})`` with boundary ranks
    ``r_0 = r_ndim = 1`` and interior ranks ``rank``. The shorter of
    ``in_shape`` / ``out_shape`` is left-padded with 1s.

    ``set_rank`` chooses how many interior bond channels ``forward`` uses.
    The cores themselves are never reallocated, so an optimizer created
    beforehand keeps updating the live parameters and the rank can be raised
    again later.

    Attributes:
        compr: dense parameter count divided by this layer's parameter count.
        active_rank: bond rank currently used by ``forward`` (1..rank).
    """

    def __init__(self, in_shape, out_shape, rank=8, bias=True):
        super().__init__()
        in_shape = tuple(in_shape)
        out_shape = tuple(out_shape)
        max_d = max(len(in_shape), len(out_shape))
        in_shape = (1,) * (max_d - len(in_shape)) + in_shape
        out_shape = (1,) * (max_d - len(out_shape)) + out_shape
        self.in_shape, self.out_shape = in_shape, out_shape
        self.rank, self.ndim = rank, len(in_shape)
        self.active_rank = rank
        self.in_feat = math.prod(in_shape)
        self.out_feat = math.prod(out_shape)
        self.cores = nn.ParameterList()
        for k in range(self.ndim):
            rl = 1 if k == 0 else rank
            rr = 1 if k == self.ndim - 1 else rank
            c = torch.empty(rl, out_shape[k], in_shape[k], rr)
            # Glorot-style uniform bound whose fan includes the bond ranks.
            bnd = math.sqrt(6.0 / max(1, rl * in_shape[k] + rr * out_shape[k]))
            nn.init.uniform_(c, -bnd, bnd)
            self.cores.append(nn.Parameter(c))
        self.bias = nn.Parameter(torch.zeros(self.out_feat)) if bias else None
        tp = sum(c.numel() for c in self.cores) + (self.bias.numel() if bias else 0)
        self.compr = (self.in_feat * self.out_feat) / max(tp, 1)

    def forward(self, x):
        """Apply the TT weight to the last axis of ``x`` (size ``prod(in_shape)``)."""
        lead = x.shape[:-1]
        B = math.prod(lead)
        r = self.active_rank
        state = x.reshape(B, *self.in_shape)
        for k in range(self.ndim):
            # Restrict bond dims to the active rank. Boundary bonds have size 1,
            # so [:r] with r >= 1 leaves them intact. Slicing (rather than
            # replacing the Parameter) keeps gradients flowing into the
            # registered core that the optimizer holds.
            core = self.cores[k][:r, :, :, :r]
            r_k, o_k, i_k, r_kp1 = core.shape
            if k == 0:
                # (B, rest, i_0) @ (i_0, o_0 * r_1) -> state (B, r_1, o_0 * rest)
                rest = math.prod(self.in_shape[1:])
                s = state.reshape(B, i_k, rest)
                cm = core.squeeze(0).permute(1, 0, 2).reshape(i_k, o_k * r_kp1)
                s = torch.matmul(s.transpose(1, 2), cm)
                s = s.reshape(B, rest, o_k, r_kp1).permute(0, 3, 2, 1)
                # Explicit size (not -1) so an empty batch still reshapes.
                state = s.reshape(B, r_kp1, o_k * rest)
            elif k == self.ndim - 1:
                # Contract the last bond and last input mode -> (B, prod(out)).
                prev_os = math.prod(self.out_shape[:k])
                s = state.reshape(B, r_k, prev_os, i_k)
                s = torch.einsum('brpi,roi->bpo', s, core.squeeze(-1))
                state = s.reshape(B, prev_os * o_k)
            else:
                # Contract bond r_k and input mode i_k; emit bond r_{k+1}.
                prev_os = math.prod(self.out_shape[:k])
                rest_in = math.prod(self.in_shape[k + 1:])
                s = state.reshape(B, r_k, prev_os, i_k, rest_in)
                s = torch.einsum('brpix,roiq->bpoqx', s, core)
                s = s.permute(0, 3, 1, 2, 4)
                state = s.reshape(B, r_kp1, prev_os * o_k * rest_in)
        out = state.reshape(B, self.out_feat)
        if self.bias is not None:
            out = out + self.bias
        return out.reshape(*lead, self.out_feat)

    def set_rank(self, nr):
        """Use the first ``nr`` bond channels (clamped to ``[1, rank]``) in forward.

        Channels beyond the active rank stay allocated, so they are restored
        if the rank is raised again. They receive no gradient while inactive.
        """
        self.active_rank = max(1, min(int(nr), self.rank))


# ═════════════════════════════════════════════════════════════════════
# 2. QUANTUM ANGLE EMBEDDING (PennyLane)
# ═════════════════════════════════════════════════════════════════════
class QuantumEmbed(nn.Module):
    """Angle embedding → variational circuit → PauliZ expectations.

    Each of the first ``n_q`` input features is applied directly as an RX
    angle on its own wire (no rescaling). ``layers`` rounds of trainable RY
    rotations follow, each with a CNOT chain (plus a closing CNOT from the
    last wire back to wire 0 when ``n_q > 2``). The output is the PauliZ
    expectation of wires ``0 .. n_out-1``. ``n_out`` defaults to ``n_q`` and
    presumably must not exceed it.
    """

    def __init__(self, n_q=4, layers=2, n_out=None):
        super().__init__()
        self.n_q, self.layers = n_q, layers
        n_out = n_out or n_q
        dev = qml.device("default.qubit", wires=n_q)

        @qml.qnode(dev, interface="torch", diff_method="backprop")
        def circ(inputs, w):
            for i in range(n_q):
                qml.RX(inputs[..., i], wires=i)
            for L in range(layers):
                for i in range(n_q):
                    qml.RY(w[L, i], wires=i)
                for i in range(n_q - 1):
                    qml.CNOT(wires=[i, i + 1])
                if n_q > 2:
                    qml.CNOT(wires=[n_q - 1, 0])
            return [qml.expval(qml.PauliZ(i)) for i in range(n_out)]

        # TorchLayer registers ``w`` (shape (layers, n_q)) as a torch Parameter.
        self.qlayer = qml.qnn.TorchLayer(circ, {"w": (layers, n_q)})

    def forward(self, x):
        return self.qlayer(x)


# ═════════════════════════════════════════════════════════════════════
# 3. TT FEED-FORWARD
# ═════════════════════════════════════════════════════════════════════
class TTFFN(nn.Module):
    """FFN ``D -> D*ff_mult -> D`` built from two TTLinear layers with GELU between."""

    def __init__(self, D, ff_mult=4, rank=8):
        super().__init__()
        E = D * ff_mult
        self.up = TTLinear(auto_factor(D), auto_factor(E), rank, True)
        self.down = TTLinear(auto_factor(E), auto_factor(D), rank, True)

    def forward(self, x):
        return self.down(F.gelu(self.up(x)))

    def set_rank(self, r):
        """Apply the same bond rank to both TT projections."""
        self.up.set_rank(r)
        self.down.set_rank(r)
# ═════════════════════════════════════════════════════════════════════
# 4. RANK SCHEDULER
# ═════════════════════════════════════════════════════════════════════
class RankScheduler(nn.Module):
    """Map attention entropy to a TT rank: rank = r_min + alpha * EMA(entropy).

    The EMA of the detached, averaged entropy advances only in training mode,
    so evaluation never mutates scheduler state. Non-finite entropy values
    are ignored so a single NaN cannot poison the EMA forever. The result is
    rounded and clamped to ``[mn, mx]``; in training it is also stored in the
    ``cur`` buffer.

    Note: ``alpha`` is registered as a Parameter, but the rank is produced via
    ``round()`` / ``item()``, so no gradient ever reaches it.
    """

    def __init__(self, mn=2, mx=16, a=2.0, sm=0.9):
        super().__init__()
        self.mn, self.mx = mn, mx
        self.alpha = nn.Parameter(torch.tensor(a))
        self.sm = sm
        self.register_buffer('ema', torch.tensor(0.5))
        self.register_buffer('cur', torch.tensor(float(mx)))

    def forward(self, ent):
        s = ent.detach().mean()
        if self.training and torch.isfinite(s):
            # In-place update keeps the registered buffer object (and device).
            self.ema.mul_(self.sm).add_((1 - self.sm) * s)
        raw = self.mn + self.alpha.detach() * self.ema
        r = int(torch.clamp(raw, self.mn, self.mx).round().item())
        if self.training:
            self.cur.fill_(r)
        return r

    @property
    def current(self):
        """Rank most recently chosen in training mode."""
        return int(self.cur.item())


# ═════════════════════════════════════════════════════════════════════
# 5. QUANTUM ROUTER
# ═════════════════════════════════════════════════════════════════════
class QuantumRouter(nn.Module):
    """Learned gate: routes only hard tokens through the quantum module.

    A small MLP scores each token in ``(0, 1)``. Tokens scoring above ``thr``
    are run through ``qmod`` and projected back to width ``D`` when ``qmod``
    emits a different width. All other tokens are returned unchanged.

    Args:
        D: token feature width.
        qmod: module mapping ``(N, D)`` features to ``(N, q_out)``.
        thr: threshold on the sigmoid gate score.
        q_out: output width of ``qmod``. If ``None``, it is inferred once at
            construction by running ``qmod`` on a single zero vector.

    Returns from ``forward``: ``(out, g)``, where ``out`` is ``(B, S, D)`` and
    ``g`` is the ``(B, S)`` gate score.
    """

    def __init__(self, D, qmod, thr=0.5, q_out=None):
        super().__init__()
        self.qmod = qmod
        self.thr = thr
        self.gate = nn.Sequential(
            nn.Linear(D, D // 4), nn.ReLU(), nn.Linear(D // 4, 1), nn.Sigmoid())
        if q_out is None:
            with torch.no_grad():
                q_out = int(qmod(torch.zeros(1, D)).shape[-1])
        # Built eagerly rather than lazily in forward(). This way the
        # projection is in parameters() when the optimizer is created, is
        # moved by .to(), and is present in a fresh model's state_dict()
        # (strict checkpoint loading works).
        self._proj = nn.Linear(q_out, D) if q_out != D else nn.Identity()
        self.register_buffer('tot', torch.tensor(0.0))
        self.register_buffer('qtok', torch.tensor(0.0))

    def forward(self, x):
        B, S, D = x.shape
        g = self.gate(x.reshape(-1, D)).squeeze(-1).reshape(B, S)
        hard = (g > self.thr).to(x.dtype)
        # Straight-through estimator: value equals the hard 0/1 decision, while
        # the gradient flows into the soft gate score g.
        m = hard + g - g.detach() if self.training else hard
        xf = x.reshape(-1, D)
        mf = m.reshape(-1)
        idx = hard.reshape(-1) > 0.5
        out = xf.clone()
        if bool(idx.any()):
            qo = self._proj(self.qmod(xf[idx]).to(xf.dtype))
            # Multiplying by the (numerically 1) mask lets the STE gradient
            # reach the gate. Only routed tokens contribute; unrouted tokens
            # never run qmod, so they have no path back to the gate.
            out[idx] = qo * mf[idx].unsqueeze(-1)
        # Plain counters: detached so no autograd graph accumulates in the
        # buffers across steps.
        with torch.no_grad():
            self.tot += B * S
            self.qtok += hard.sum()
        return out.reshape(B, S, D), g

    def sparsity(self):
        """Fraction of tokens seen so far that were NOT routed (1.0 if none seen)."""
        if self.tot > 0:
            return 1.0 - (self.qtok / self.tot).item()
        return 1.0
# ═════════════════════════════════════════════════════════════════════
# 6. ATTENTION
# ═════════════════════════════════════════════════════════════════════
class MHA(nn.Module):
    """Multi-head self-attention.

    Args:
        D: model width (must be divisible by ``heads``).
        heads: number of attention heads.
        drop: dropout applied to the attention weights before mixing values.
        causal: if True (default), query ``t`` may only attend to keys
            ``<= t``. Both models in this file are trained on next-token
            prediction, so non-causal attention would let each position
            read its own label.

    ``forward(x, mask)`` takes ``x`` of shape ``(B, S, D)`` and an optional
    ``(B, S)`` key mask where 0 marks keys to ignore. It returns
    ``(output, probs)``, where ``probs`` (``(B, heads, S, S)``) is the
    pre-dropout softmax distribution. Query rows with no permitted key get
    all-zero weights instead of NaN.
    """

    def __init__(self, D, heads=4, drop=0.1, causal=True):
        super().__init__()
        assert D % heads == 0
        self.h, self.hd = heads, D // heads
        self.scale = self.hd ** -0.5
        self.causal = causal
        self.qkv = nn.Linear(D, 3 * D, bias=False)
        self.out = nn.Linear(D, D)
        self.drop = nn.Dropout(drop)

    def forward(self, x, mask=None):
        B, S, D = x.shape
        qkv = self.qkv(x).reshape(B, S, 3, self.h, self.hd).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        scores = (q @ k.transpose(-2, -1)) * self.scale
        allowed = None
        if self.causal:
            pos = torch.arange(S, device=x.device)
            allowed = pos[None, :] <= pos[:, None]            # (S, S) lower-triangular
        if mask is not None:
            keep = (mask != 0)[:, None, None, :]               # (B, 1, 1, S)
            allowed = keep if allowed is None else allowed & keep
        if allowed is not None:
            scores = scores.masked_fill(~allowed, float('-inf'))
        probs = F.softmax(scores, dim=-1)
        if allowed is not None:
            # A row whose keys are all masked softmaxes to NaN; zero it instead.
            probs = probs.masked_fill(~allowed.any(dim=-1, keepdim=True), 0.0)
        o = (self.drop(probs) @ v).transpose(1, 2).reshape(B, S, D)
        # Return the undropped distribution so callers (e.g. entropy) see
        # weights that actually sum to 1.
        return self.out(o), probs


# ═════════════════════════════════════════════════════════════════════
# 7. HYBRID BLOCK
# ═════════════════════════════════════════════════════════════════════
class HybridBlock(nn.Module):
    """Pre-norm block: attention, optional quantum routing, then a TT FFN.

    The mean attention entropy of each forward pass feeds a RankScheduler.
    When ``adapt`` is True, the resulting rank is applied to the TT FFN.
    ``forward`` returns a dict with the block output (``'out'``), attention
    weights (``'aw'``), scalar ``'entropy'``, the ``'rank'`` used or reported,
    and the router's ``'qsparse'`` (1.0 if there is no router).
    """

    def __init__(self, cfg):
        super().__init__()
        D = cfg.d_model
        self.a_norm = nn.LayerNorm(D)
        self.attn = MHA(D, cfg.n_heads, cfg.dropout)
        self.f_norm = nn.LayerNorm(D)
        self.ffn = TTFFN(D, cfg.ff_mult, cfg.tt_rank)
        self.qrouter = None
        if cfg.q_qubits:
            qc = QuantumEmbed(cfg.q_qubits, cfg.q_layers, cfg.q_qubits)
            qw = nn.Sequential(nn.Linear(D, cfg.q_qubits), qc)
            self.qrouter = QuantumRouter(D, qw)
        self.rs = RankScheduler(cfg.min_rank, cfg.tt_rank, cfg.rank_alpha, cfg.rank_smoothing)
        self.drop = nn.Dropout(cfg.dropout)

    def forward(self, x, mask=None, adapt=True):
        ao, aw = self.attn(self.a_norm(x), mask)
        x = x + self.drop(ao)
        eps = 1e-8
        # Entropy of each query's attention row (sum over keys), averaged over
        # queries, then over batch and heads -> scalar.
        ent = -torch.sum(aw * torch.log(aw + eps), dim=-1).mean(dim=-1).mean()
        tr = self.rs(ent) if adapt else self.rs.mx
        if adapt:
            self.ffn.set_rank(tr)
        n = self.f_norm(x)
        qs = 1.0
        if self.qrouter is not None:
            qo, _ = self.qrouter(n)
            # NOTE(review): numerically this is n + drop(qo). Tokens the router
            # did not select come back as a copy of n, so (in eval mode) they
            # enter the FFN as 2 * n. Confirm this is intended rather than
            # n + drop(qo - n).
            n = n + self.drop(qo - n.detach() + n)
            qs = self.qrouter.sparsity()
        x = x + self.drop(self.ffn(n))
        return {'out': x, 'aw': aw, 'entropy': ent, 'rank': tr, 'qsparse': qs}
# ═════════════════════════════════════════════════════════════════════
# 8. Q-TENSORFORMER MODEL
# ═════════════════════════════════════════════════════════════════════
class QTensorFormer(nn.Module):
    """LM backbone: token + learned positional embeddings, HybridBlock stack,
    final LayerNorm, and an output head tied to the token embedding."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.tok = nn.Embedding(cfg.vocab, cfg.d_model)
        self.pos = nn.Parameter(torch.randn(1, cfg.max_seq, cfg.d_model) * 0.02)
        self.layers = nn.ModuleList([HybridBlock(cfg) for _ in range(cfg.n_layers)])
        self.norm = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab, bias=False)
        self.head.weight = self.tok.weight  # weight tying
        self._init()

    def _init(self):
        """Xavier-init every >=2-D parameter except TT cores.

        TTLinear initializes its own cores with a rank-aware bound.
        ``xavier_uniform_`` on a 4-D core derives fan-in/out from the wrong
        axes and (for the default config) shrinks the reconstructed FFN weight
        by roughly an order of magnitude, so those cores are left alone.
        """
        tt_ids = {id(p) for m in self.modules() if isinstance(m, TTLinear)
                  for p in m.parameters()}
        for p in self.parameters():
            if p.dim() >= 2 and id(p) not in tt_ids:
                nn.init.xavier_uniform_(p)

    def forward(self, ids, mask=None, adapt=True):
        """Run the model on ``ids`` of shape ``(B, S)``.

        ``mask`` is an optional ``(B, S)`` key mask (0 = ignore). It is passed
        to each block unchanged; the attention layer adds the head/query
        axes itself.

        Returns a dict with ``logits`` ``(B, S, vocab)``, the mean per-layer
        ``entropy``, mean ``rank``, and mean ``qsparse``.

        Raises:
            ValueError: if ``S`` exceeds ``cfg.max_seq``.
        """
        B, S = ids.shape
        if S > self.cfg.max_seq:
            raise ValueError(f"sequence length {S} exceeds max_seq={self.cfg.max_seq}")
        x = self.tok(ids) + self.pos[:, :S, :]
        block_outs = []
        for layer in self.layers:
            o = layer(x, mask, adapt)
            x = o['out']
            block_outs.append(o)
        x = self.norm(x)
        logits = self.head(x)
        n = len(block_outs)
        ent = torch.stack([b['entropy'] for b in block_outs]).mean()
        rk = sum(b['rank'] for b in block_outs) / n
        qs = sum(b['qsparse'] for b in block_outs) / n
        return {'logits': logits, 'entropy': ent, 'rank': rk, 'qsparse': qs}

    def loss(self, ids, mask=None, labels=None):
        """Next-token cross-entropy: logits[:, t] predict labels[:, t+1]."""
        if labels is None:
            labels = ids.clone()
        out = self(ids, mask)
        sl = out['logits'][:, :-1].contiguous()
        ll = labels[:, 1:].contiguous()
        l = F.cross_entropy(sl.reshape(-1, self.cfg.vocab), ll.reshape(-1), ignore_index=-100)
        return {'loss': l, 'ppl': torch.exp(l), 'entropy': out['entropy'],
                'rank': out['rank'], 'qsparse': out['qsparse']}

    def nparams(self):
        """Total and trainable parameter counts (tied weights counted once)."""
        t = sum(p.numel() for p in self.parameters())
        tr = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {'total': t, 'trainable': tr}
# ═════════════════════════════════════════════════════════════════════
# 9. BASELINE (same backbone, dense FFN)
# ═════════════════════════════════════════════════════════════════════
class Baseline(nn.Module):
    """Dense reference model with the same embeddings, attention, and tied head.

    It uses a plain Linear-GELU-Dropout-Linear FFN and has no quantum routing
    or rank adaptation. It also applies dropout to the summed embeddings.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.tok = nn.Embedding(cfg.vocab, cfg.d_model)
        self.pos = nn.Parameter(torch.randn(1, cfg.max_seq, cfg.d_model) * 0.02)
        self.drop = nn.Dropout(cfg.dropout)
        self.layers = nn.ModuleList()
        for _ in range(cfg.n_layers):
            self.layers.append(nn.ModuleDict({
                'a_n': nn.LayerNorm(cfg.d_model),
                'a': MHA(cfg.d_model, cfg.n_heads, cfg.dropout),
                'f_n': nn.LayerNorm(cfg.d_model),
                'ff': nn.Sequential(
                    nn.Linear(cfg.d_model, cfg.d_model * cfg.ff_mult),
                    nn.GELU(),
                    nn.Dropout(cfg.dropout),
                    nn.Linear(cfg.d_model * cfg.ff_mult, cfg.d_model)),
            }))
        self.norm = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab, bias=False)
        self.head.weight = self.tok.weight  # weight tying
        self._init()

    def _init(self):
        """Xavier-init every parameter with two or more dimensions."""
        for p in self.parameters():
            if p.dim() >= 2:
                nn.init.xavier_uniform_(p)

    def forward(self, ids, mask=None):
        """Return ``{'logits': (B, S, vocab)}`` for ``ids`` of shape ``(B, S)``.

        ``mask`` is an optional ``(B, S)`` key mask (0 = ignore), passed to the
        attention layer unchanged; it adds the head/query axes itself.

        Raises:
            ValueError: if ``S`` exceeds ``cfg.max_seq``.
        """
        B, S = ids.shape
        if S > self.cfg.max_seq:
            raise ValueError(f"sequence length {S} exceeds max_seq={self.cfg.max_seq}")
        x = self.drop(self.tok(ids) + self.pos[:, :S, :])
        for blk in self.layers:
            ao, _ = blk['a'](blk['a_n'](x), mask)
            x = x + self.drop(ao)
            x = x + self.drop(blk['ff'](blk['f_n'](x)))
        return {'logits': self.head(self.norm(x))}

    def loss(self, ids, mask=None, labels=None):
        """Next-token cross-entropy: logits[:, t] predict labels[:, t+1]."""
        if labels is None:
            labels = ids.clone()
        out = self(ids, mask)
        sl = out['logits'][:, :-1].contiguous()
        ll = labels[:, 1:].contiguous()
        l = F.cross_entropy(sl.reshape(-1, self.cfg.vocab), ll.reshape(-1), ignore_index=-100)
        return {'loss': l, 'ppl': torch.exp(l)}

    def nparams(self):
        """Total and trainable parameter counts (tied weights counted once)."""
        t = sum(p.numel() for p in self.parameters())
        tr = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {'total': t, 'trainable': tr}
# ═════════════════════════════════════════════════════════════════════
# 10. TRAINING UTILITIES
# ═════════════════════════════════════════════════════════════════════
def make_data(vocab=1000, seq=64, n=500, bs=16):
    """Return a shuffled DataLoader over ``n`` uniformly random token sequences.

    Token ids are drawn from ``[1, vocab)``. Each batch is a dict holding only
    ``'input_ids'``, a LongTensor of shape ``(batch, seq)``.
    """
    data = torch.randint(1, vocab, (n, seq))
    ds = torch.utils.data.TensorDataset(data)
    return torch.utils.data.DataLoader(
        ds, batch_size=bs, shuffle=True,
        collate_fn=lambda batch: {'input_ids': torch.stack([item[0] for item in batch])})


def _perplexity(mean_loss):
    """exp(mean NLL) in float64; yields inf instead of raising on overflow."""
    return torch.exp(torch.tensor(mean_loss, dtype=torch.float64)).item()


def train_epoch(model, dl, opt, sched, e, tag="M"):
    """Run one training epoch and print a summary line.

    Loss is averaged over sequences (batches weighted by size). Perplexity is
    ``exp`` of that mean rather than an average of per-batch perplexities.
    Optional diagnostics (``entropy``, ``rank``, ``qsparse``) are averaged
    per batch.

    Returns:
        ``(mean_loss, perplexity)``; ``(0.0, 0.0)`` for an empty loader.
    """
    model.train()
    loss_sum, n_seq, nb = 0.0, 0, 0
    extras = {}
    for b in dl:
        ids = b['input_ids']
        m = b.get('attention_mask')
        opt.zero_grad()
        out = model.loss(ids, m)
        out['loss'].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        if sched:
            sched.step()
        w = ids.shape[0]
        loss_sum += out['loss'].item() * w
        n_seq += w
        nb += 1
        for k in ('entropy', 'rank', 'qsparse'):
            if k in out:
                v = out[k]
                extras[k] = extras.get(k, 0.0) + (v.item() if isinstance(v, torch.Tensor) else v)
    al = loss_sum / max(n_seq, 1)
    ap = _perplexity(al) if n_seq else 0.0
    s = f"[{tag}] E{e:2d} loss={al:.4f} ppl={ap:6.1f}"
    for k, v in extras.items():
        s += f" {k}={v/max(nb,1):.3f}"
    print(s)
    return al, ap


@torch.no_grad()
def evaluate(model, dl):
    """Evaluate without gradients.

    Returns:
        ``(mean_loss, perplexity)``, where the loss is weighted by batch size
        and the perplexity is ``exp(mean_loss)``; ``(0.0, 0.0)`` when ``dl``
        is empty.
    """
    model.eval()
    loss_sum, n_seq = 0.0, 0
    for b in dl:
        ids = b['input_ids']
        out = model.loss(ids, b.get('attention_mask'))
        loss_sum += out['loss'].item() * ids.shape[0]
        n_seq += ids.shape[0]
    if not n_seq:
        return 0.0, 0.0
    mean = loss_sum / n_seq
    return mean, _perplexity(mean)
# ═════════════════════════════════════════════════════════════════════
# 11. MAIN BENCHMARK
# ═════════════════════════════════════════════════════════════════════
def main():
    """Train Q-TensorFormer and the dense baseline on the same random data.

    Both models are evaluated, and their state_dicts are saved to the
    platform temp directory to measure on-disk size. A comparison table is
    printed. Returns a dict of the headline metrics.
    """
    import tempfile  # only main() needs a scratch location

    torch.manual_seed(42)
    cfg = CFG(d_model=64, n_layers=2, n_heads=4, tt_rank=8, q_qubits=4,
              q_sparsity=0.3, vocab=1000, max_seq=64)
    print(f"Config: d={cfg.d_model} layers={cfg.n_layers} heads={cfg.n_heads} rank={cfg.tt_rank}")
    print(f"Quantum: qubits={cfg.q_qubits} sparsity={cfg.q_sparsity}")
    print(f"Tensor FFN: ON\n")

    qt = QTensorFormer(cfg)
    bl = Baseline(cfg)
    pq = qt.nparams()
    pb = bl.nparams()
    print(f"Q-TensorFormer params: {pq['trainable']:>10,}")
    print(f"Baseline params: {pb['trainable']:>10,}")
    print(f"Compression ratio: {pb['trainable']/max(pq['trainable'],1):>10.1f}x\n")

    train_dl = make_data(cfg.vocab, cfg.max_seq, 500, 16)
    val_dl = make_data(cfg.vocab, cfg.max_seq, 100, 16)
    E = 8

    print("=" * 50)
    print(" TRAINING Q-TENSORFORMER")
    print("=" * 50)
    oq = torch.optim.AdamW(qt.parameters(), lr=cfg.lr)
    sq = torch.optim.lr_scheduler.CosineAnnealingLR(oq, E * len(train_dl))
    for e in range(1, E + 1):
        train_epoch(qt, train_dl, oq, sq, e, "Q-TF")

    print("\n" + "=" * 50)
    print(" TRAINING BASELINE")
    print("=" * 50)
    ob = torch.optim.AdamW(bl.parameters(), lr=cfg.lr)
    sb = torch.optim.lr_scheduler.CosineAnnealingLR(ob, E * len(train_dl))
    for e in range(1, E + 1):
        train_epoch(bl, train_dl, ob, sb, e, "BSL")

    ql, qp = evaluate(qt, val_dl)
    bl_val, bp = evaluate(bl, val_dl)

    # gettempdir() honours TMPDIR/TEMP/TMP and falls back to /tmp on POSIX,
    # so this keeps the old location on Linux while also working on Windows.
    tmp_dir = tempfile.gettempdir()
    qt_path = os.path.join(tmp_dir, 'qt.pt')
    bl_path = os.path.join(tmp_dir, 'bl.pt')
    torch.save(qt.state_dict(), qt_path)
    torch.save(bl.state_dict(), bl_path)
    qsz = os.path.getsize(qt_path) / (1024 * 1024)
    bsz = os.path.getsize(bl_path) / (1024 * 1024)

    print("\n" + "=" * 65)
    print(" RESULTS")
    print("=" * 65)
    print(f"{'Metric':<30} {'Q-TensorFormer':>15} {'Baseline':>15}")
    print("-" * 60)
    print(f"{'Parameters':<30} {pq['trainable']:>13,} {pb['trainable']:>13,}")
    print(f"{'Val Loss':<30} {ql:>15.4f} {bl_val:>15.4f}")
    print(f"{'Val Perplexity':<30} {qp:>15.2f} {bp:>15.2f}")
    print(f"{'Model Size (MB)':<30} {qsz:>15.1f} {bsz:>15.1f}")

    ps = (1 - pq['trainable'] / pb['trainable']) * 100
    ss = (1 - qsz / bsz) * 100
    pr = qp / bp
    print(f"\nParameter reduction: {ps:.1f}%")
    print(f"Size reduction: {ss:.1f}%")
    print(f"PPL ratio (Q-TF/BL): {pr:.2f}x")
    if pr < 1.1:
        print(f"\n >> VERDICT: Significant compression with minimal quality loss! <<")
    elif pr < 1.3:
        print(f"\n >> VERDICT: Moderate trade-off — compression worth the cost <<")
    else:
        print(f"\n >> VERDICT: Quality gap too large, needs tuning <<")
    print("\nDone!")
    return {'params_q': pq['trainable'], 'params_b': pb['trainable'],
            'qloss': ql, 'qppl': qp, 'bloss': bl_val, 'bppl': bp,
            'qsz': qsz, 'bsz': bsz, 'comp': ps, 'sred': ss, 'ppl_ratio': pr}


if __name__ == '__main__':
    results = main()