# Q-TensorFormer/src/baselines.py
"""
Baseline implementations for fair comparison.
Baselines:
1. Standard Transformer: Dense MLP FFN, no TT, no quantum.
2. Distilled: Smaller transformer trained with KD.
3. Pruned: Magnitude-based structured pruning.
4. TT-Only: Tensor network FFN without quantum or adaptive rank.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional
class StandardTransformer(nn.Module):
"""
Basic transformer decoder (GPT-style) with dense MLP FFN.
Reference baseline — matches Q-TensorFormer architecture
exactly except for TT decomposition and quantum layers.
"""
def __init__(self, vocab_size: int = 10000, d_model: int = 128,
n_heads: int = 4, n_layers: int = 2, ff_mult: int = 4,
max_seq_len: int = 128, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.config = type("config", (), {
"d_model": d_model, "n_heads": n_heads, "n_layers": n_layers,
"ff_multiplier": ff_mult, "max_seq_len": max_seq_len,
"vocab_size": vocab_size, "dropout": dropout,
})()
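        # Hyperparameters are exposed through a small anonymous namespace,
        # presumably so downstream code can read model.config uniformly
        # across the baselines and Q-TensorFormer.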
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = _PositionalEncoding(d_model, max_seq_len, dropout)
self.blocks = nn.ModuleList([
_StandardBlock(d_model, n_heads, ff_mult, dropout, max_seq_len)
for _ in range(n_layers)
])
self.ln_f = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
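        # Tie the output projection to the input embedding matrix (standard
        # weight tying, which also keeps the parameter count down).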
self.lm_head.weight = self.embedding.weight
def forward(self, input_ids, attention_mask=None, return_stats=False):
x = self.embedding(input_ids)
x = self.pos_encoding(x)
for block in self.blocks:
x = block(x, mask=attention_mask)
x = self.ln_f(x)
logits = self.lm_head(x)
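        # Baselines have no adaptive-rank statistics to report; an empty list
        # presumably keeps the return signature compatible with Q-TensorFormer.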
if return_stats:
return logits, []
return logits
@property
def total_params(self) -> int:
return sum(p.numel() for p in self.parameters())
class DistilledTransformer(nn.Module):
"""
Smaller transformer trained via knowledge distillation.
Designed to match Q-TensorFormer parameter counts.
"""
def __init__(self, vocab_size: int = 10000, d_model: int = 96,
n_heads: int = 4, n_layers: int = 2, ff_mult: int = 3,
max_seq_len: int = 128, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.config = type("config", (), {
"d_model": d_model, "n_heads": n_heads, "n_layers": n_layers,
"ff_multiplier": ff_mult, "max_seq_len": max_seq_len,
"vocab_size": vocab_size, "dropout": dropout,
})()
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = _PositionalEncoding(d_model, max_seq_len, dropout)
self.blocks = nn.ModuleList([
_StandardBlock(d_model, n_heads, ff_mult, dropout, max_seq_len)
for _ in range(n_layers)
])
self.ln_f = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
self.lm_head.weight = self.embedding.weight
def forward(self, input_ids, attention_mask=None, return_stats=False):
x = self.embedding(input_ids)
x = self.pos_encoding(x)
for block in self.blocks:
x = block(x, mask=attention_mask)
x = self.ln_f(x)
logits = self.lm_head(x)
if return_stats:
return logits, []
return logits
@property
def total_params(self) -> int:
return sum(p.numel() for p in self.parameters())
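# The distillation objective itself is not part of this file; the sketch below
# shows the usual KD loss (Hinton et al.) that DistilledTransformer is assumed
# to be trained with. The name, temperature, and alpha mix are illustrative
# defaults, not values taken from the Q-TensorFormer training code.
def distillation_loss_sketch(student_logits, teacher_logits, labels,
                             temperature: float = 2.0, alpha: float = 0.5):
    """Blend soft-target KL (teacher -> student) with hard-label cross-entropy."""
    soft = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)
    hard = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)), labels.view(-1)
    )
    return alpha * soft + (1.0 - alpha) * hard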
class PrunedTransformer(nn.Module):
"""
    Magnitude-pruned standard transformer.
    Zeroes FFN weights against a single global magnitude threshold so that the
    remaining non-zero parameter count roughly matches Q-TensorFormer.
    Rows left nearly empty by the mask are zeroed entirely, adding a coarse
    structured component on top of the element-wise magnitude pruning.
"""
def __init__(self, base_model: StandardTransformer,
prune_ratio: float = 0.5):
super().__init__()
self.base = base_model
self.prune_ratio = prune_ratio
self.config = base_model.config
self._prune()
    def _prune(self):
        """Apply global magnitude pruning to the FFN layers, zeroing near-empty rows."""
        with torch.no_grad():
            # Gather every FFN weight so a single global threshold can be computed.
            all_weights = [
                layer.weight.flatten()
                for block in self.base.blocks
                for layer in (block.ffn[0], block.ffn[2])
            ]
            flat = torch.cat(all_weights).abs()
            k = int(len(flat) * self.prune_ratio)
            if k == 0:
                return  # prune_ratio too small to remove anything
            # The k-th smallest absolute value is the global pruning threshold.
            threshold = torch.kthvalue(flat, k).values
            for block in self.base.blocks:
                for layer in (block.ffn[0], block.ffn[2]):
                    mask = (layer.weight.abs() > threshold).float()
                    # Rows with fewer than 10% surviving weights are zeroed
                    # entirely, giving a coarse structured effect on top of
                    # the element-wise magnitude mask.
                    dead_rows = mask.sum(dim=1) < layer.weight.size(1) * 0.1
                    mask[dead_rows] = 0
                    layer.weight.mul_(mask)
def forward(self, *args, **kwargs):
return self.base(*args, **kwargs)
@property
def total_params(self) -> int:
return sum(p.numel() for p in self.parameters())
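# Note that `total_params` above counts zeroed weights as well, since pruning
# only masks values in place. A helper like the sketch below (not part of the
# original evaluation code) can be used to report the effective size of a
# pruned model.
def nonzero_params(model: nn.Module) -> int:
    """Count parameters that survived pruning (non-zero entries only)."""
    return sum(int(p.count_nonzero()) for p in model.parameters())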
class _StandardBlock(nn.Module):
"""Standard transformer decoder block."""
def __init__(self, d_model, n_heads, ff_mult, dropout, max_seq_len):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.attn = _CausalAttention(d_model, n_heads, dropout, max_seq_len)
self.ln2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_model * ff_mult),
nn.GELU(),
nn.Linear(d_model * ff_mult, d_model),
nn.Dropout(dropout),
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
x = x + self.dropout(self.attn(self.ln1(x), mask=mask))
x = x + self.ffn(self.ln2(x))
return x
class _CausalAttention(nn.Module):
"""Causal multi-head attention."""
def __init__(self, d_model, n_heads, dropout, max_seq_len):
super().__init__()
assert d_model % n_heads == 0
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.scale = math.sqrt(self.head_dim)
self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
self.out_proj = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
self.max_seq_len = max_seq_len
def forward(self, x, mask=None):
B, T, C = x.shape
qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
q, k, v = qkv.unbind(dim=2)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
attn = (q @ k.transpose(-2, -1)) / self.scale
        # Causal mask: each position may attend only to itself and earlier positions.
        causal = torch.triu(
            torch.full((T, T), float("-inf"), device=x.device), diagonal=1
        )
        attn = attn + causal
        if mask is not None:
            # attention_mask is assumed to be 1 for real tokens and 0 for padding;
            # masked_fill avoids the NaNs produced by multiplying 0 by float("-inf").
            attn = attn.masked_fill(mask[:, None, None, :] == 0, float("-inf"))
attn = F.softmax(attn, dim=-1)
attn = self.dropout(attn)
out = (attn @ v).transpose(1, 2).reshape(B, T, C)
return self.out_proj(out)
class _PositionalEncoding(nn.Module):
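    """Fixed sinusoidal positional encoding (Vaswani et al., 2017) followed by dropout."""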
def __init__(self, d_model, max_len, dropout):
super().__init__()
self.dropout = nn.Dropout(dropout)
pe = torch.zeros(max_len, d_model)
pos = torch.arange(max_len).unsqueeze(1).float()
div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(pos * div)
pe[:, 1::2] = torch.cos(pos * div)
self.register_buffer("pe", pe.unsqueeze(0))
def forward(self, x):
return self.dropout(x + self.pe[:, :x.size(1)])
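

# Minimal smoke test (illustrative only; the vocabulary size, sequence length,
# and pruning ratio below are arbitrary and not taken from the Q-TensorFormer
# experiment configs).
if __name__ == "__main__":
    torch.manual_seed(0)
    tokens = torch.randint(0, 10000, (2, 32))
    models = [
        ("standard", StandardTransformer(vocab_size=10000)),
        ("pruned", PrunedTransformer(StandardTransformer(vocab_size=10000), prune_ratio=0.5)),
        ("distilled", DistilledTransformer(vocab_size=10000)),
    ]
    for name, model in models:
        logits = model(tokens)
        print(f"{name}: params={model.total_params:,}, "
              f"nonzero={nonzero_params(model):,}, logits shape={tuple(logits.shape)}")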