| """ |
| Baseline implementations for fair comparison. |
| |
| Baselines: |
| 1. Standard Transformer: Dense MLP FFN, no TT, no quantum. |
| 2. Distilled: Smaller transformer trained with KD. |
| 3. Pruned: Magnitude-based structured pruning. |
| 4. TT-Only: Tensor network FFN without quantum or adaptive rank. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from typing import Optional |
|
|
|
|
class StandardTransformer(nn.Module):
    """
    Basic transformer decoder (GPT-style) with dense MLP FFN.

    Reference baseline: matches the Q-TensorFormer architecture
    exactly except for TT decomposition and quantum layers.
    """

    def __init__(self, vocab_size: int = 10000, d_model: int = 128,
                 n_heads: int = 4, n_layers: int = 2, ff_mult: int = 4,
                 max_seq_len: int = 128, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.config = type("config", (), {
            "d_model": d_model, "n_heads": n_heads, "n_layers": n_layers,
            "ff_multiplier": ff_mult, "max_seq_len": max_seq_len,
            "vocab_size": vocab_size, "dropout": dropout,
        })()

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = _PositionalEncoding(d_model, max_seq_len, dropout)

        self.blocks = nn.ModuleList([
            _StandardBlock(d_model, n_heads, ff_mult, dropout, max_seq_len)
            for _ in range(n_layers)
        ])

        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight

    def forward(self, input_ids, attention_mask=None, return_stats=False):
        x = self.embedding(input_ids)
        x = self.pos_encoding(x)

        for block in self.blocks:
            x = block(x, mask=attention_mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        if return_stats:
            return logits, []
        return logits

    @property
    def total_params(self) -> int:
        return sum(p.numel() for p in self.parameters())


class DistilledTransformer(nn.Module):
    """
    Smaller transformer trained via knowledge distillation.

    Designed to match Q-TensorFormer parameter counts.
    """

    def __init__(self, vocab_size: int = 10000, d_model: int = 96,
                 n_heads: int = 4, n_layers: int = 2, ff_mult: int = 3,
                 max_seq_len: int = 128, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.config = type("config", (), {
            "d_model": d_model, "n_heads": n_heads, "n_layers": n_layers,
            "ff_multiplier": ff_mult, "max_seq_len": max_seq_len,
            "vocab_size": vocab_size, "dropout": dropout,
        })()

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = _PositionalEncoding(d_model, max_seq_len, dropout)

        self.blocks = nn.ModuleList([
            _StandardBlock(d_model, n_heads, ff_mult, dropout, max_seq_len)
            for _ in range(n_layers)
        ])

        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight

    def forward(self, input_ids, attention_mask=None, return_stats=False):
        x = self.embedding(input_ids)
        x = self.pos_encoding(x)

        for block in self.blocks:
            x = block(x, mask=attention_mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        if return_stats:
            return logits, []
        return logits

    @property
    def total_params(self) -> int:
        return sum(p.numel() for p in self.parameters())


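
# Illustrative helper (not referenced elsewhere in this module): the
# DistilledTransformer baseline is described as "trained with KD"; this is
# the standard Hinton-style objective such training typically uses. The
# temperature and alpha defaults are placeholders, not experiment settings.
def distillation_loss(student_logits, teacher_logits, labels,
                      temperature: float = 2.0, alpha: float = 0.5):
    """Soft-target KL against a teacher plus hard-label cross-entropy."""
    s = student_logits.reshape(-1, student_logits.size(-1))
    t = teacher_logits.reshape(-1, teacher_logits.size(-1))
    # Soft targets: KL divergence between temperature-scaled distributions,
    # rescaled by T^2 so gradients keep a comparable magnitude.
    soft = F.kl_div(
        F.log_softmax(s / temperature, dim=-1),
        F.softmax(t / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)
    # Hard targets: ordinary next-token cross-entropy.
    hard = F.cross_entropy(s, labels.reshape(-1))
    return alpha * soft + (1.0 - alpha) * hard

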
class PrunedTransformer(nn.Module):
    """
    Magnitude-pruned standard transformer.

    Prunes FFN weights with a single global magnitude threshold so the
    nonzero parameter count roughly matches Q-TensorFormer, then zeroes
    output channels that retain fewer than 10% of their weights to give
    a coarse structured effect.
    """

    def __init__(self, base_model: StandardTransformer,
                 prune_ratio: float = 0.5):
        super().__init__()
        self.base = base_model
        self.prune_ratio = prune_ratio
        self.config = base_model.config
        self._prune()

    def _prune(self):
        """Apply global magnitude pruning to the FFN layers."""
        with torch.no_grad():
            # Gather all FFN weights to compute one global threshold.
            all_weights = []
            for block in self.base.blocks:
                for weight in [block.ffn[0].weight, block.ffn[2].weight]:
                    all_weights.append(weight.flatten())

            flat = torch.cat(all_weights)
            k = int(len(flat) * self.prune_ratio)
            if k < 1:
                return
            # Threshold = largest magnitude among the k smallest entries.
            threshold = torch.topk(flat.abs(), k, largest=False).values[-1]

            for block in self.base.blocks:
                for layer in [block.ffn[0], block.ffn[2]]:
                    mask = (layer.weight.abs() > threshold).float()
                    # Coarse structured step: drop output channels that keep
                    # fewer than 10% of their weights.
                    row_norms = mask.sum(dim=1)
                    dead_rows = row_norms < layer.weight.size(1) * 0.1
                    mask[dead_rows] = 0
                    layer.weight.mul_(mask)

    def forward(self, *args, **kwargs):
        return self.base(*args, **kwargs)

    @property
    def total_params(self) -> int:
        # Pruned weights are zeroed in place, so this count matches the
        # dense base model.
        return sum(p.numel() for p in self.parameters())


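
# Example (illustrative): wrap a dense baseline and zero out roughly half of
# its FFN weights by magnitude. `input_ids` is a (batch, seq_len) LongTensor
# of token ids.
#
#     dense = StandardTransformer(vocab_size=10000)
#     pruned = PrunedTransformer(dense, prune_ratio=0.5)
#     logits = pruned(input_ids)

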
class _StandardBlock(nn.Module):
    """Standard transformer decoder block."""

    def __init__(self, d_model, n_heads, ff_mult, dropout, max_seq_len):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = _CausalAttention(d_model, n_heads, dropout, max_seq_len)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * ff_mult),
            nn.GELU(),
            nn.Linear(d_model * ff_mult, d_model),
            nn.Dropout(dropout),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = x + self.dropout(self.attn(self.ln1(x), mask=mask))
        x = x + self.ffn(self.ln2(x))
        return x


class _CausalAttention(nn.Module):
    """Causal multi-head attention."""

    def __init__(self, d_model, n_heads, dropout, max_seq_len):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = math.sqrt(self.head_dim)

        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

        self.max_seq_len = max_seq_len

    def forward(self, x, mask=None):
        B, T, C = x.shape
        qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) / self.scale

        # Causal mask: block attention to future positions.
        causal = torch.triu(
            torch.full((T, T), float("-inf"), device=x.device), diagonal=1
        )
        attn = attn + causal

        # Padding mask: `mask` is (B, T) with 1 for real tokens, 0 for padding.
        # Use masked_fill rather than multiplying by -inf, which would turn
        # padded positions into NaN (0 * -inf) and mask the valid ones instead.
        if mask is not None:
            attn = attn.masked_fill(
                mask.unsqueeze(1).unsqueeze(2) == 0, float("-inf")
            )

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, T, C)
        return self.out_proj(out)


class _PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding with dropout."""

    def __init__(self, d_model, max_len, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return self.dropout(x + self.pe[:, :x.size(1)])
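

if __name__ == "__main__":
    # Minimal smoke test (illustrative): build each baseline, run one forward
    # pass on random token ids, and print parameter counts. Shapes and
    # hyperparameters here are arbitrary, not the experiment settings.
    torch.manual_seed(0)
    input_ids = torch.randint(0, 10000, (2, 32))

    baselines = {
        "standard": StandardTransformer(vocab_size=10000),
        "distilled": DistilledTransformer(vocab_size=10000),
        "pruned": PrunedTransformer(StandardTransformer(vocab_size=10000),
                                    prune_ratio=0.5),
    }
    for name, model in baselines.items():
        model.eval()
        with torch.no_grad():
            logits = model(input_ids)
        print(f"{name:>9}: params={model.total_params:,} "
              f"logits={tuple(logits.shape)}")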