Q-TensorFormer / q_tensor_former.py
Premchan369's picture
Upload q_tensor_former.py
79a43db verified
#!/usr/bin/env python3
"""
Q-TensorFormer: Quantum-Enhanced Tensor Network LLM Compression Engine
=======================================================================
Hybrid quantum-tensor transformer with:
- Pure PyTorch Tensor-Train FFN layers (no compiled deps)
- PennyLane quantum angle encoding with TorchLayer
- Entanglement-guided adaptive rank scheduling
- Selective quantum routing (only "hard" tokens)
- Full benchmark against identical-architecture baseline
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, os
from typing import Optional, Tuple
from dataclasses import dataclass
import pennylane as qml
# Import-time banner: title framed by rules, then library versions and a blank line.
print(
    "=" * 65,
    " Q-TENSORFORMER: Quantum-Tensor LLM Compressor",
    "=" * 65,
    f" PyTorch {torch.__version__} | PennyLane {qml.__version__}",
    "",
    sep="\n",
)
# ═════════════════════════════════════════════════════════════════════
# CONFIG
# ═════════════════════════════════════════════════════════════════════
@dataclass
class CFG:
    """Hyper-parameters consumed by QTensorFormer, HybridBlock, Baseline and main()."""
    # Transformer layout.
    d_model: int = 64
    n_heads: int = 4
    n_layers: int = 2
    ff_mult: int = 4          # FFN hidden width = d_model * ff_mult
    max_seq: int = 64         # length of the learned positional-embedding table
    vocab: int = 1000
    # Tensor-train FFN.
    tt_rank: int = 8          # rank the TT cores are built with; upper bound for the rank scheduler
    min_rank: int = 2         # lower bound for the rank scheduler
    # Quantum routing.
    q_qubits: int = 4         # qubits per circuit; a falsy value (0) builds blocks without a router
    q_layers: int = 2         # variational layers in the circuit
    q_sparsity: float = 0.3   # NOTE(review): only printed by main(); the router uses its own fixed threshold
    dropout: float = 0.1
    lr: float = 3e-4          # AdamW learning rate used by main()
    # Rank scheduler: rank = min_rank + rank_alpha * EMA(attention entropy).
    rank_alpha: float = 2.0
    rank_smoothing: float = 0.9  # EMA decay factor
# ═════════════════════════════════════════════════════════════════════
# 1. PURE PYTORCH TENSOR-TRAIN LINEAR LAYER
# ═════════════════════════════════════════════════════════════════════
def auto_factor(n, max_f=4):
    """Split ``n`` into at most ``max_f`` factors for use as a TT mode shape.

    Factors of 2 are peeled off first, then 3, 5 and 7, while fewer than
    ``max_f`` factors exist. Any remainder becomes its own factor if there is
    room, otherwise it is multiplied into the last factor. The result is
    left-padded with 1s to at least two entries, and its product always
    equals ``n``.

    Args:
        n: positive integer to factor.
        max_f: maximum number of factors. Values below 2 are treated as 2,
            because the result always has at least two entries.

    Returns:
        Tuple of ints whose product is ``n``, e.g. 64 -> (2, 2, 2, 8).

    Raises:
        ValueError: if ``n`` < 1.
    """
    if n < 1:
        raise ValueError(f"auto_factor needs a positive integer, got {n}")
    if n == 1:
        return (1, 1)
    # A cap below 2 conflicts with the two-entry minimum; truncating the
    # padded result would silently drop part of n.
    max_f = max(2, max_f)
    factors, rest = [], n
    # Each while-loop exhausts its prime, so every prime is listed once.
    for p in (2, 3, 5, 7):
        while rest % p == 0 and len(factors) < max_f:
            factors.append(p)
            rest //= p
    if rest > 1:
        if len(factors) < max_f:
            factors.append(rest)
        else:
            factors[-1] *= rest
    while len(factors) < 2:
        factors.insert(0, 1)
    return tuple(factors)
class TTLinear(nn.Module):
    """Tensor-Train decomposed linear layer. Pure PyTorch, zero compiled deps.

    The dense weight W of shape (out_feat, in_feat) is represented by ``ndim``
    cores. Core k has shape (r_{k-1}, out_shape[k], in_shape[k], r_k), with
    the two boundary ranks fixed to 1. Input and output features are
    flattened row-major over ``in_shape`` / ``out_shape``.

    Args:
        in_shape: factorisation of the input size (product = in_feat).
        out_shape: factorisation of the output size (product = out_feat).
            The shorter shape is left-padded with 1s so both have ``ndim`` modes.
        rank: internal TT rank the cores are allocated with.
        bias: add a learnable bias of size out_feat.

    Attributes:
        compr: dense parameter count (in_feat * out_feat) divided by this
            layer's parameter count (cores + bias) at the allocated rank.
        active_rank: rank actually used by ``forward``; see ``set_rank``.
    """

    def __init__(self, in_shape, out_shape, rank=8, bias=True):
        super().__init__()
        in_shape = tuple(in_shape)
        out_shape = tuple(out_shape)
        max_d = max(len(in_shape), len(out_shape))
        in_shape = (1,) * (max_d - len(in_shape)) + in_shape
        out_shape = (1,) * (max_d - len(out_shape)) + out_shape
        self.in_shape, self.out_shape = in_shape, out_shape
        self.rank, self.ndim = rank, len(in_shape)
        # Rank used by forward(); starts at the full allocated rank.
        self.active_rank = rank
        self.in_feat = math.prod(in_shape)
        self.out_feat = math.prod(out_shape)
        self.cores = nn.ParameterList()
        for k in range(self.ndim):
            rl = 1 if k == 0 else rank
            rr = 1 if k == self.ndim - 1 else rank
            c = torch.empty(rl, out_shape[k], in_shape[k], rr)
            # Glorot-style bound from this core's own fan-in / fan-out.
            bnd = math.sqrt(6.0 / max(1, rl * in_shape[k] + rr * out_shape[k]))
            nn.init.uniform_(c, -bnd, bnd)
            # Wrap explicitly; older ParameterList versions reject plain tensors.
            self.cores.append(nn.Parameter(c))
        self.bias = nn.Parameter(torch.zeros(self.out_feat)) if bias else None
        tp = sum(c.numel() for c in self.cores) + (self.bias.numel() if bias else 0)
        self.compr = (self.in_feat * self.out_feat) / max(tp, 1)

    def forward(self, x):
        """Apply the layer to the last dimension of ``x``, which must have size in_feat.

        Leading dimensions are treated as batch dimensions and restored on
        output, whose last dimension is out_feat.

        Raises:
            ValueError: if ``x.shape[-1] != in_feat``.
        """
        if x.shape[-1] != self.in_feat:
            raise ValueError(
                f"TTLinear expected last dim {self.in_feat}, got {x.shape[-1]}")
        bs = x.shape[:-1]
        B = math.prod(bs)
        r = self.active_rank
        state = x.reshape(B, *self.in_shape)
        for k in range(self.ndim):
            # Slice the stored core down to the active rank. Boundary ranks are
            # 1, and r >= 1, so they are left intact.
            core = self.cores[k][:r, :, :, :r]
            r_k, o_k, i_k, r_kp1 = core.shape
            if k == 0:
                rest = math.prod(self.in_shape[1:])
                s = state.reshape(B, i_k, rest)
                cm = core.squeeze(0).permute(1, 0, 2).reshape(i_k, o_k * r_kp1)
                # (B, rest, i_k) @ (i_k, o_k * r_kp1): broadcast matmul, no expand.
                s = torch.matmul(s.transpose(1, 2), cm)
                s = s.reshape(B, rest, o_k, r_kp1).permute(0, 3, 2, 1)
                # Explicit sizes (not -1) so an empty batch (B == 0) still reshapes.
                state = s.reshape(B, r_kp1, o_k * rest)
            elif k == self.ndim - 1:
                prev_os = math.prod(self.out_shape[:k])
                s = state.reshape(B, r_k, prev_os, i_k)
                s = torch.einsum('brpi,roi->bpo', s, core.squeeze(-1))
                state = s.reshape(B, prev_os * o_k)
            else:
                prev_os = math.prod(self.out_shape[:k])
                rest_in = math.prod(self.in_shape[k + 1:])
                s = state.reshape(B, r_k, prev_os, i_k, rest_in)
                s = torch.einsum('brpix,roiq->bpoqx', s, core)
                s = s.permute(0, 3, 1, 2, 4)
                state = s.reshape(B, r_kp1, prev_os * o_k * rest_in)
        out = state.reshape(B, self.out_feat)
        if self.bias is not None:
            out = out + self.bias
        return out.reshape(*bs, self.out_feat)

    def set_rank(self, nr):
        """Use only the first ``nr`` internal TT ranks in ``forward``.

        ``nr`` is clamped to [1, rank]. The stored cores are sliced on the fly
        rather than replaced. As a result the optimizer keeps updating the
        same Parameter objects, state_dict shapes stay stable, and the rank
        can be raised again later.
        """
        self.active_rank = max(1, min(int(nr), self.rank))
# ═════════════════════════════════════════════════════════════════════
# 2. QUANTUM ANGLE EMBEDDING (PennyLane)
# ═════════════════════════════════════════════════════════════════════
class QuantumEmbed(nn.Module):
    """Angle embedding -> variational circuit -> PauliZ expectations.

    Input feature i (for i < n_q) is encoded as an RX angle on qubit i. The
    circuit then applies ``layers`` rounds of trainable RY rotations, each
    followed by a CNOT chain that is closed into a ring when n_q > 2. The
    module outputs the PauliZ expectation values of the first ``n_out``
    qubits.

    Args:
        n_q: number of qubits (and of leading input features consumed).
        layers: number of variational layers; weights have shape (layers, n_q).
        n_out: number of measured qubits; defaults to n_q. Each output is read
            from a distinct qubit, so it must satisfy 1 <= n_out <= n_q.

    Raises:
        ValueError: if n_q < 1 or n_out is outside [1, n_q].
    """
    def __init__(self, n_q=4, layers=2, n_out=None):
        super().__init__()
        n_out = n_out or n_q
        if n_q < 1:
            raise ValueError(f"n_q must be >= 1, got {n_q}")
        if not 1 <= n_out <= n_q:
            raise ValueError(f"n_out must be in [1, n_q={n_q}], got {n_out}")
        self.n_q, self.layers = n_q, layers
        dev = qml.device("default.qubit", wires=n_q)

        @qml.qnode(dev, interface="torch", diff_method="backprop")
        def circ(inputs, w):
            for i in range(n_q):
                qml.RX(inputs[..., i], wires=i)
            for L in range(layers):
                for i in range(n_q):
                    qml.RY(w[L, i], wires=i)
                for i in range(n_q - 1):
                    qml.CNOT(wires=[i, i + 1])
                if n_q > 2:
                    qml.CNOT(wires=[n_q - 1, 0])
            return [qml.expval(qml.PauliZ(i)) for i in range(n_out)]

        self.qlayer = qml.qnn.TorchLayer(circ, {"w": (layers, n_q)})

    def forward(self, x):
        """Run the circuit on ``x``; only its first n_q trailing features are encoded.

        Raises:
            ValueError: if ``x.shape[-1] < n_q``.
        """
        if x.shape[-1] < self.n_q:
            raise ValueError(
                f"QuantumEmbed needs at least {self.n_q} input features, got {x.shape[-1]}")
        return self.qlayer(x)
# ═════════════════════════════════════════════════════════════════════
# 3. TT FEED-FORWARD
# ═════════════════════════════════════════════════════════════════════
class TTFFN(nn.Module):
    """Feed-forward block: TT up-projection D -> D*ff_mult, GELU, TT down-projection.

    Both projections use ``auto_factor`` mode shapes and share the same TT rank.
    """
    def __init__(self, D, ff_mult=4, rank=8):
        super().__init__()
        hidden = D * ff_mult
        d_modes = auto_factor(D)
        h_modes = auto_factor(hidden)
        self.up = TTLinear(d_modes, h_modes, rank=rank, bias=True)
        self.down = TTLinear(h_modes, d_modes, rank=rank, bias=True)

    def forward(self, x):
        expanded = F.gelu(self.up(x))
        return self.down(expanded)

    def set_rank(self, r):
        """Forward the requested TT rank to both projections."""
        for proj in (self.up, self.down):
            proj.set_rank(r)
# ═════════════════════════════════════════════════════════════════════
# 4. RANK SCHEDULER
# ═════════════════════════════════════════════════════════════════════
class RankScheduler(nn.Module):
    """rank = r_min + alpha * entropy (EMA-smoothed)

    Args:
        mn: lower bound of the returned rank.
        mx: upper bound of the returned rank; also the initial value of ``cur``.
        a: value of ``alpha``.
        sm: EMA decay, applied as ema <- sm * ema + (1 - sm) * mean(entropy).

    The EMA and ``cur`` only advance in training mode. During evaluation the
    scheduler returns the rank implied by the current EMA without mutating
    it. ``alpha`` stays an nn.Parameter for state_dict compatibility, but the
    rank is converted to a Python int, so no gradient ever reaches it.
    """
    def __init__(self, mn=2, mx=16, a=2.0, sm=0.9):
        super().__init__()
        self.mn, self.mx = mn, mx
        self.alpha = nn.Parameter(torch.tensor(a))
        self.sm = sm
        self.register_buffer('ema', torch.tensor(0.5))
        self.register_buffer('cur', torch.tensor(float(mx)))

    def forward(self, ent):
        """Return the integer rank for entropy ``ent`` (any shape; its mean is used).

        An empty ``ent`` leaves the EMA untouched.
        """
        if self.training and ent.numel() > 0:
            # Always reduce to a 0-d tensor so the 'ema' buffer keeps its shape.
            s = ent.detach().mean().to(self.ema.dtype)
            self.ema = self.sm * self.ema + (1 - self.sm) * s
        with torch.no_grad():
            raw = self.mn + self.alpha * self.ema
            r = int(torch.clamp(raw, self.mn, self.mx).round().item())
        if self.training:
            self.cur.fill_(r)
        return r

    @property
    def current(self):
        """Rank most recently chosen in training mode (``mx`` before any)."""
        return int(self.cur.item())
# ═════════════════════════════════════════════════════════════════════
# 5. QUANTUM ROUTER
# ═════════════════════════════════════════════════════════════════════
class QuantumRouter(nn.Module):
    """Learned gate: routes only hard tokens through quantum circuit.

    A small MLP scores every token with a sigmoid gate g. Tokens with
    g > thr are sent through ``qmod``, projected back to width D when needed,
    and replace their input. All other tokens pass through unchanged.

    Args:
        D: token width.
        qmod: module mapping (N, D) -> (N, q_out).
        thr: gate threshold.
        q_out: output width of ``qmod``. If None, it is inferred once here by
            running ``qmod`` on a single zero token.

    The q_out -> D projection (only when q_out != D) is created at
    construction. It is therefore in ``parameters()`` when an optimizer is
    built and in ``state_dict()`` from the start.

    Buffers ``tot`` and ``qtok`` count tokens seen and tokens routed, in both
    train and eval mode.
    """
    def __init__(self, D, qmod, thr=0.5, q_out=None):
        super().__init__()
        self.qmod = qmod
        self.thr = thr
        self.gate = nn.Sequential(
            nn.Linear(D, D // 4), nn.ReLU(), nn.Linear(D // 4, 1), nn.Sigmoid())
        self.register_buffer('tot', torch.tensor(0.0))
        self.register_buffer('qtok', torch.tensor(0.0))
        if q_out is None:
            q_out = self._probe_out_dim(qmod, D)
        self._proj = nn.Linear(q_out, D) if q_out != D else None

    @staticmethod
    def _probe_out_dim(qmod, D):
        """Return qmod's output width by running it once, in eval mode and without grad."""
        ref = next(qmod.parameters(), None)
        kw = {} if ref is None else {'device': ref.device, 'dtype': ref.dtype}
        was_training = qmod.training
        qmod.eval()
        try:
            with torch.no_grad():
                return qmod(torch.zeros(1, D, **kw)).shape[-1]
        finally:
            qmod.train(was_training)

    def forward(self, x):
        """Route tokens of ``x`` (B, S, D); return (output (B, S, D), gate (B, S))."""
        B, S, D = x.shape
        g = self.gate(x.reshape(-1, D)).squeeze(-1).reshape(B, S)
        xf = x.reshape(-1, D)
        hard = (g > self.thr).reshape(-1)
        out = xf.clone()
        if hard.any():
            qo = self.qmod(xf[hard])
            if self._proj is not None:
                qo = self._proj(qo)
            out[hard] = qo.to(out.dtype)
        if self.training:
            # Straight-through term: exactly zero in value, but its gradient
            # w.r.t. g is (out - xf). Without it the hard selection gives the
            # gate no gradient at all.
            st = (g - g.detach()).reshape(-1, 1)
            out = out + st * (out - xf)
        # Statistics only; keep the buffers out of the autograd graph.
        with torch.no_grad():
            self.tot += B * S
            self.qtok += hard.sum().to(self.qtok.dtype)
        return out.reshape(B, S, D), g

    def sparsity(self):
        """Fraction of tokens seen so far that were not routed (1.0 before any token)."""
        if self.tot > 0:
            return 1.0 - (self.qtok / self.tot).item()
        return 1.0
# ═════════════════════════════════════════════════════════════════════
# 6. ATTENTION
# ═════════════════════════════════════════════════════════════════════
class MHA(nn.Module):
    """Multi-head self-attention that also returns its attention probabilities.

    Args:
        D: model width; must be divisible by ``heads``.
        heads: number of attention heads.
        drop: dropout applied to the attention probabilities before mixing values.
        causal: if True (default), query i attends only to keys <= i. The
            models in this file train on a next-token loss (logits[:, :-1]
            against ids[:, 1:]). Unmasked attention would let each position
            read the very token it is asked to predict.

    forward(x, mask=None):
        x: (B, S, D).
        mask: optional key-padding mask of shape (B, S), or already
            broadcastable to (B, heads, S, S) such as (B, 1, 1, S). Keys
            where mask == 0 are excluded. A query whose keys are all excluded
            yields NaN probabilities.
        Returns (output of shape (B, S, D), probabilities of shape
        (B, heads, S, S)). The probabilities are taken before dropout, so each
        row is a proper distribution over keys.
    """
    def __init__(self, D, heads=4, drop=0.1, causal=True):
        super().__init__()
        if D % heads != 0:
            raise ValueError(f"D={D} must be divisible by heads={heads}")
        self.h, self.hd = heads, D // heads
        self.scale = self.hd ** -0.5
        self.causal = causal
        self.qkv = nn.Linear(D, 3 * D, bias=False)
        self.out = nn.Linear(D, D)
        self.drop = nn.Dropout(drop)

    def forward(self, x, mask=None):
        B, S, D = x.shape
        qkv = self.qkv(x).reshape(B, S, 3, self.h, self.hd).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        a = (q @ k.transpose(-2, -1)) * self.scale
        if self.causal:
            future = torch.ones(S, S, dtype=torch.bool, device=x.device).triu(1)
            a = a.masked_fill(future, float('-inf'))
        if mask is not None:
            # Accept a plain (B, S) padding mask or a pre-broadcast one;
            # re-expanding a 4-D mask would give an unbroadcastable 6-D tensor.
            if mask.dim() == 2:
                mask = mask[:, None, None, :]
            a = a.masked_fill(mask == 0, float('-inf'))
        probs = F.softmax(a, dim=-1)
        o = (self.drop(probs) @ v).transpose(1, 2).reshape(B, S, D)
        return self.out(o), probs
# ═════════════════════════════════════════════════════════════════════
# 7. HYBRID BLOCK
# ═════════════════════════════════════════════════════════════════════
class HybridBlock(nn.Module):
    """One Q-TensorFormer layer: attention, optional quantum routing, TT feed-forward.

    Reads from ``cfg``: d_model, n_heads, dropout, ff_mult, tt_rank, q_qubits,
    q_layers, min_rank, rank_alpha and rank_smoothing.
    """
    def __init__(self, cfg):
        super().__init__()
        D = cfg.d_model
        self.a_norm = nn.LayerNorm(D)
        self.attn = MHA(D, cfg.n_heads, cfg.dropout)
        self.f_norm = nn.LayerNorm(D)
        self.ffn = TTFFN(D, cfg.ff_mult, cfg.tt_rank)
        self.qrouter = None
        # A falsy q_qubits (e.g. 0) builds the block without a quantum router.
        if cfg.q_qubits:
            qc = QuantumEmbed(cfg.q_qubits, cfg.q_layers, cfg.q_qubits)
            # Linear(D -> q_qubits) supplies one rotation angle per qubit.
            qw = nn.Sequential(nn.Linear(D, cfg.q_qubits), qc)
            self.qrouter = QuantumRouter(D, qw)
        # Schedules the FFN's TT rank between min_rank and tt_rank.
        self.rs = RankScheduler(cfg.min_rank, cfg.tt_rank, cfg.rank_alpha, cfg.rank_smoothing)
        self.drop = nn.Dropout(cfg.dropout)
    def forward(self, x, mask=None, adapt=True):
        """Apply the block.

        Args:
            x: hidden states; assumed (batch, seq, d_model), since the attention unpacks three dims.
            mask: optional padding mask, forwarded unchanged to self.attn.
            adapt: if True, feed the attention entropy to the rank scheduler
                and set the FFN to the rank it returns.

        Returns:
            dict with 'out' (new hidden states), 'aw' (attention weights from
            self.attn), 'entropy' (scalar mean attention entropy), 'rank'
            (chosen rank, or rs.mx when adapt is False) and 'qsparse'
            (router sparsity, 1.0 when there is no router).
        """
        # Pre-norm attention with a residual connection.
        ao, aw = self.attn(self.a_norm(x), mask)
        x = x + self.drop(ao)
        eps=1e-8
        # Entropy over the key axis, averaged over queries and then over all remaining axes.
        ent = -torch.sum(aw*torch.log(aw+eps), dim=-1).mean(dim=-1).mean()
        # NOTE(review): with adapt=False the reported rank is rs.mx, but
        # ffn.set_rank is not called. The FFN therefore keeps whatever rank
        # it was last given. Confirm this mismatch is intended.
        tr = self.rs(ent) if adapt else self.rs.mx
        if adapt: self.ffn.set_rank(tr)
        n = self.f_norm(x)
        qs = 1.0
        if self.qrouter is not None:
            qo, _ = self.qrouter(n)
            # Numerically ~ n + drop(qo). The `- n.detach() + n` pair adds an
            # extra identity gradient path back to n.
            # NOTE(review): QuantumRouter passes un-routed tokens through
            # unchanged (qo == n for them), so those tokens become ~2n before
            # the FFN. Confirm that a residual like `n + drop(qo - n)` was not
            # intended.
            n = n + self.drop(qo - n.detach() + n)
            qs = self.qrouter.sparsity()
        x = x + self.drop(self.ffn(n))
        return {'out':x, 'aw':aw, 'entropy':ent, 'rank':tr, 'qsparse':qs}
# ═════════════════════════════════════════════════════════════════════
# 8. Q-TENSORFORMER MODEL
# ═════════════════════════════════════════════════════════════════════
class QTensorFormer(nn.Module):
    """Language model built from HybridBlocks with an output head tied to the token embedding.

    Args:
        cfg: CFG instance. Reads vocab, d_model, max_seq and n_layers here;
            each HybridBlock reads further fields.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.tok = nn.Embedding(cfg.vocab, cfg.d_model)
        self.pos = nn.Parameter(torch.randn(1, cfg.max_seq, cfg.d_model) * 0.02)
        self.layers = nn.ModuleList([HybridBlock(cfg) for _ in range(cfg.n_layers)])
        self.norm = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.head.weight = self.tok.weight
        self._init()

    def _init(self):
        """Xavier-initialise every >=2-D parameter except TT cores.

        TTLinear already draws each core from a bound derived from its own
        rank and mode sizes. Xavier on a 4-D core folds the trailing dims into
        the receptive field, which can shrink the reconstructed dense weight
        far below a dense layer's scale.
        """
        tt_cores = {id(p) for m in self.modules() if isinstance(m, TTLinear) for p in m.cores}
        for p in self.parameters():
            if p.dim() >= 2 and id(p) not in tt_cores:
                nn.init.xavier_uniform_(p)

    def forward(self, ids, mask=None, adapt=True):
        """Run the model.

        Args:
            ids: (B, S) token ids with S <= cfg.max_seq.
            mask: optional (B, S) padding mask (0 = padding), passed to every block.
            adapt: forwarded to every HybridBlock.

        Returns:
            dict with 'logits' (B, S, vocab), 'entropy' (mean over layers),
            'rank' and 'qsparse' (averages over layers).

        Raises:
            ValueError: if S exceeds cfg.max_seq.
        """
        B, S = ids.shape
        if S > self.cfg.max_seq:
            raise ValueError(f"sequence length {S} exceeds max_seq={self.cfg.max_seq}")
        x = self.tok(ids) + self.pos[:, :S, :]
        # Pass the (B, S) mask through unchanged: MHA inserts the head/query
        # axes itself. Expanding it here as well produced a 6-D mask that
        # cannot broadcast against the (B, heads, S, S) scores.
        bos = []
        for l in self.layers:
            o = l(x, mask, adapt)
            x = o['out']
            bos.append(o)
        x = self.norm(x)
        logits = self.head(x)
        ent = torch.stack([b['entropy'] for b in bos]).mean()
        rk = sum(b['rank'] for b in bos) / len(bos)
        qs = sum(b['qsparse'] for b in bos) / len(bos)
        return {'logits': logits, 'entropy': ent, 'rank': rk, 'qsparse': qs}

    def loss(self, ids, mask=None, labels=None):
        """Next-token cross-entropy (position t predicts labels[:, t + 1]).

        If ``labels`` is omitted, it defaults to ``ids``. Positions where
        ``mask == 0`` are set to the ignore index -100 so padding is not trained on.
        """
        if labels is None:
            labels = ids.clone()
            if mask is not None:
                labels = labels.masked_fill(mask == 0, -100)
        out = self(ids, mask)
        sl = out['logits'][:, :-1].contiguous()
        ll = labels[:, 1:].contiguous()
        l = F.cross_entropy(sl.reshape(-1, self.cfg.vocab), ll.reshape(-1), ignore_index=-100)
        return {'loss': l, 'ppl': torch.exp(l), 'entropy': out['entropy'],
                'rank': out['rank'], 'qsparse': out['qsparse']}

    def nparams(self):
        """Total and trainable parameter counts (tied weights counted once)."""
        t = sum(p.numel() for p in self.parameters())
        tr = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {'total': t, 'trainable': tr}
# ═════════════════════════════════════════════════════════════════════
# 9. BASELINE (identical architecture, dense FFN)
# ═════════════════════════════════════════════════════════════════════
class Baseline(nn.Module):
    """Dense-FFN reference model laid out like QTensorFormer.

    The model consists of:
    - token plus learned positional embeddings, followed by dropout;
    - ``n_layers`` pre-norm blocks, each with MHA attention and a dense
      Linear -> GELU -> Dropout -> Linear feed-forward (width d_model * ff_mult);
    - a final LayerNorm and an output head tied to the token embedding.
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.tok = nn.Embedding(cfg.vocab, cfg.d_model)
        self.pos = nn.Parameter(torch.randn(1, cfg.max_seq, cfg.d_model) * 0.02)
        self.drop = nn.Dropout(cfg.dropout)
        self.layers = nn.ModuleList()
        for _ in range(cfg.n_layers):
            self.layers.append(nn.ModuleDict({
                'a_n': nn.LayerNorm(cfg.d_model),
                'a': MHA(cfg.d_model, cfg.n_heads, cfg.dropout),
                'f_n': nn.LayerNorm(cfg.d_model),
                'ff': nn.Sequential(
                    nn.Linear(cfg.d_model, cfg.d_model * cfg.ff_mult),
                    nn.GELU(), nn.Dropout(cfg.dropout),
                    nn.Linear(cfg.d_model * cfg.ff_mult, cfg.d_model)),
            }))
        self.norm = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.head.weight = self.tok.weight
        self._init()

    def _init(self):
        """Xavier-initialise every parameter with at least two dimensions."""
        for p in self.parameters():
            if p.dim() >= 2:
                nn.init.xavier_uniform_(p)

    def forward(self, ids, mask=None):
        """Return {'logits': (B, S, vocab)} for (B, S) token ids.

        ``mask`` is an optional (B, S) padding mask (0 = padding).

        Raises:
            ValueError: if S exceeds cfg.max_seq.
        """
        B, S = ids.shape
        if S > self.cfg.max_seq:
            raise ValueError(f"sequence length {S} exceeds max_seq={self.cfg.max_seq}")
        x = self.tok(ids) + self.pos[:, :S, :]
        x = self.drop(x)
        # Pass the (B, S) mask through unchanged: MHA inserts the head/query
        # axes itself. Expanding it here as well produced a 6-D mask that
        # cannot broadcast against the attention scores.
        for l in self.layers:
            ao, _ = l['a'](l['a_n'](x), mask)
            x = x + self.drop(ao)
            x = x + self.drop(l['ff'](l['f_n'](x)))
        return {'logits': self.head(self.norm(x))}

    def loss(self, ids, mask=None, labels=None):
        """Next-token cross-entropy (position t predicts labels[:, t + 1]).

        If ``labels`` is omitted, it defaults to ``ids``. Positions where
        ``mask == 0`` are set to the ignore index -100 so padding is not trained on.
        """
        if labels is None:
            labels = ids.clone()
            if mask is not None:
                labels = labels.masked_fill(mask == 0, -100)
        out = self(ids, mask)
        sl = out['logits'][:, :-1].contiguous()
        ll = labels[:, 1:].contiguous()
        l = F.cross_entropy(sl.reshape(-1, self.cfg.vocab), ll.reshape(-1), ignore_index=-100)
        return {'loss': l, 'ppl': torch.exp(l)}

    def nparams(self):
        """Total and trainable parameter counts (tied weights counted once)."""
        t = sum(p.numel() for p in self.parameters())
        tr = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {'total': t, 'trainable': tr}
# ═════════════════════════════════════════════════════════════════════
# 10. TRAINING UTILITIES
# ═════════════════════════════════════════════════════════════════════
def make_data(vocab=1000, seq=64, n=500, bs=16):
    """Build a shuffled DataLoader over ``n`` random token sequences of length ``seq``.

    Tokens are drawn uniformly from [1, vocab). Each batch is a dict with the
    single key 'input_ids', holding a (batch, seq) integer tensor.
    """
    tokens = torch.randint(1, vocab, (n, seq))
    dataset = torch.utils.data.TensorDataset(tokens)

    def _collate(batch):
        # TensorDataset yields 1-tuples; unpack each and stack into one tensor.
        rows = [row for (row,) in batch]
        return {'input_ids': torch.stack(rows)}

    return torch.utils.data.DataLoader(
        dataset, batch_size=bs, shuffle=True, collate_fn=_collate)
def train_epoch(model, dl, opt, sched, e, tag="M"):
    """Train ``model`` for one pass over ``dl`` and print a one-line summary.

    Each batch goes through zero_grad, model.loss, backward, gradient
    clipping to norm 1.0, an optimizer step and, if given, a scheduler step.
    Optional 'entropy' / 'rank' / 'qsparse' entries in the loss dict are
    averaged into the summary line.

    Returns:
        (mean loss, mean perplexity) over batches; (0.0, 0.0) for an empty loader.
    """
    model.train()
    loss_sum, ppl_sum, n_batches = 0.0, 0.0, 0
    extras = {}
    for batch in dl:
        ids = batch['input_ids']
        attn_mask = batch.get('attention_mask')
        opt.zero_grad()
        res = model.loss(ids, attn_mask)
        res['loss'].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        if sched:
            sched.step()
        loss_sum += res['loss'].item()
        ppl_sum += res['ppl'].item()
        n_batches += 1
        for key in ('entropy', 'rank', 'qsparse'):
            if key not in res:
                continue
            val = res[key]
            if isinstance(val, torch.Tensor):
                val = val.item()
            extras[key] = extras.get(key, 0.0) + val
    denom = max(n_batches, 1)
    avg_loss, avg_ppl = loss_sum / denom, ppl_sum / denom
    line = f"[{tag}] E{e:2d} loss={avg_loss:.4f} ppl={avg_ppl:6.1f}"
    line += "".join(f" {k}={v / denom:.3f}" for k, v in extras.items())
    print(line)
    return avg_loss, avg_ppl
@torch.no_grad()
def evaluate(model, dl):
    """Average loss and perplexity of ``model`` over ``dl``, computed in eval mode without grad.

    The model is left in eval mode. Returns (0.0, 0.0) for an empty loader.
    """
    model.eval()
    losses, ppls = [], []
    for batch in dl:
        res = model.loss(batch['input_ids'], batch.get('attention_mask'))
        losses.append(res['loss'].item())
        ppls.append(res['ppl'].item())
    denom = max(len(losses), 1)
    return sum(losses) / denom, sum(ppls) / denom
# ═════════════════════════════════════════════════════════════════════
# 11. MAIN BENCHMARK
# ═════════════════════════════════════════════════════════════════════
def main(out_dir=None):
    """Benchmark Q-TensorFormer against the dense Baseline on synthetic token data.

    Steps:
    1. Build both models from one CFG.
    2. Train each for 8 epochs with AdamW and a cosine LR schedule on random tokens.
    3. Evaluate both on a separate random-token loader.
    4. Save both state_dicts to measure their on-disk size.
    5. Print a comparison table.

    Args:
        out_dir: directory for the checkpoint files 'qt.pt' and 'bl.pt'.
            Defaults to the platform temp directory (``tempfile.gettempdir()``,
            normally /tmp on Linux), so the script also runs where /tmp does
            not exist. The directory is created if missing.

    Returns:
        dict with keys:
        - params_q / params_b: trainable parameter counts;
        - qloss, qppl, bloss, bppl: validation loss and perplexity;
        - qsz / bsz: checkpoint sizes in MB;
        - comp: parameter reduction in %;
        - sred: size reduction in %;
        - ppl_ratio: Q-TF perplexity divided by baseline perplexity.
    """
    import tempfile  # only main() needs it
    torch.manual_seed(42)
    cfg = CFG(d_model=64, n_layers=2, n_heads=4, tt_rank=8,
              q_qubits=4, q_sparsity=0.3, vocab=1000, max_seq=64)
    print(f"Config: d={cfg.d_model} layers={cfg.n_layers} heads={cfg.n_heads} rank={cfg.tt_rank}")
    print(f"Quantum: qubits={cfg.q_qubits} sparsity={cfg.q_sparsity}")
    print(f"Tensor FFN: ON\n")
    qt = QTensorFormer(cfg)
    bl = Baseline(cfg)
    pq = qt.nparams()
    pb = bl.nparams()
    print(f"Q-TensorFormer params: {pq['trainable']:>10,}")
    print(f"Baseline params: {pb['trainable']:>10,}")
    print(f"Compression ratio: {pb['trainable']/max(pq['trainable'],1):>10.1f}x\n")
    train_dl = make_data(cfg.vocab, cfg.max_seq, 500, 16)
    val_dl = make_data(cfg.vocab, cfg.max_seq, 100, 16)
    E = 8
    print("=" * 50)
    print(" TRAINING Q-TENSORFORMER")
    print("=" * 50)
    oq = torch.optim.AdamW(qt.parameters(), lr=cfg.lr)
    sq = torch.optim.lr_scheduler.CosineAnnealingLR(oq, E * len(train_dl))
    for e in range(1, E + 1):
        train_epoch(qt, train_dl, oq, sq, e, "Q-TF")
    print("\n" + "=" * 50)
    print(" TRAINING BASELINE")
    print("=" * 50)
    ob = torch.optim.AdamW(bl.parameters(), lr=cfg.lr)
    sb = torch.optim.lr_scheduler.CosineAnnealingLR(ob, E * len(train_dl))
    for e in range(1, E + 1):
        train_epoch(bl, train_dl, ob, sb, e, "BSL")
    ql, qp = evaluate(qt, val_dl)
    bl_val, bp = evaluate(bl, val_dl)
    # Use a portable directory instead of hard-coded '/tmp/...' paths.
    if out_dir is None:
        out_dir = tempfile.gettempdir()
    os.makedirs(out_dir, exist_ok=True)
    q_path = os.path.join(out_dir, 'qt.pt')
    b_path = os.path.join(out_dir, 'bl.pt')
    torch.save(qt.state_dict(), q_path)
    torch.save(bl.state_dict(), b_path)
    qsz = os.path.getsize(q_path) / (1024 * 1024)
    bsz = os.path.getsize(b_path) / (1024 * 1024)
    print("\n" + "=" * 65)
    print(" RESULTS")
    print("=" * 65)
    print(f"{'Metric':<30} {'Q-TensorFormer':>15} {'Baseline':>15}")
    print("-" * 60)
    print(f"{'Parameters':<30} {pq['trainable']:>13,} {pb['trainable']:>13,}")
    print(f"{'Val Loss':<30} {ql:>15.4f} {bl_val:>15.4f}")
    print(f"{'Val Perplexity':<30} {qp:>15.2f} {bp:>15.2f}")
    print(f"{'Model Size (MB)':<30} {qsz:>15.1f} {bsz:>15.1f}")
    ps = (1 - pq['trainable'] / pb['trainable']) * 100
    ss = (1 - qsz / bsz) * 100
    pr = qp / bp
    print(f"\nParameter reduction: {ps:.1f}%")
    print(f"Size reduction: {ss:.1f}%")
    print(f"PPL ratio (Q-TF/BL): {pr:.2f}x")
    if pr < 1.1:
        print(f"\n >> VERDICT: Significant compression with minimal quality loss! <<")
    elif pr < 1.3:
        print(f"\n >> VERDICT: Moderate trade-off — compression worth the cost <<")
    else:
        print(f"\n >> VERDICT: Quality gap too large, needs tuning <<")
    print("\nDone!")
    return {'params_q': pq['trainable'], 'params_b': pb['trainable'], 'qloss': ql, 'qppl': qp,
            'bloss': bl_val, 'bppl': bp, 'qsz': qsz, 'bsz': bsz, 'comp': ps, 'sred': ss,
            'ppl_ratio': pr}
if __name__ == '__main__':
    # Run the benchmark when executed as a script. The returned summary dict
    # stays bound to `results` (useful when run with `python -i`).
    results = main()