"""
Q-TensorFormer v3: Complete Model Architectures.
Model variants:
- QTensorFormer: Full hybrid model (TT-FFN + quantum + adaptive rank)
- TensorBaseline: TT-FFN only (no quantum, fixed rank)
- DenseBaseline: Standard transformer (no TT, no quantum)
- Distilled: knowledge-distillation student (a QTensorFormer with halved TT rank, built via create_model)
"""
import torch
import torch.nn as nn
import math
from typing import Optional, Dict
from .blocks import HybridBlock
from .config import ModelConfig


class PositionalEncoding(nn.Module):
"""Fixed sinusoidal positional encoding."""
def __init__(self, d_model: int, max_len: int = 128, dropout: float = 0.1):
super().__init__()
self.dropout = nn.Dropout(dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2, dtype=torch.float32) *
(-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
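        # pe now holds the standard sinusoidal table:
        #   PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        #   PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
        # It is sliced to the input length in forward() and added to the token
        # embeddings; the table itself has no learned parameters.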
self.register_buffer("pe", pe.unsqueeze(0))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.dropout(x + self.pe[:, :x.size(1), :])


class QTensorFormer(nn.Module):
"""
Quantum-Enhanced Tensor Network Transformer.
Full hybrid model: replaces FFN with TT decomposition and adds
quantum feature routing with adaptive rank scheduling.
Parameters
----------
config : ModelConfig
Model configuration.
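
    Example
    -------
    Illustrative only; assumes ModelConfig exposes these fields as keyword
    arguments and supplies defaults for everything not set here.

    >>> cfg = ModelConfig(vocab_size=1000, d_model=64, n_heads=4, n_layers=2)
    >>> model = QTensorFormer(cfg)
    >>> logits = model(torch.randint(0, cfg.vocab_size, (2, 16)))
    >>> tuple(logits.shape)
    (2, 16, 1000)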
"""
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
# Embeddings
self.embedding = nn.Embedding(config.vocab_size, config.d_model)
self.pos_encoding = PositionalEncoding(
config.d_model, config.max_seq_len, config.dropout
)
# Transformer blocks
self.blocks = nn.ModuleList([
HybridBlock(
d_model=config.d_model,
n_heads=config.n_heads,
ff_multiplier=config.ff_multiplier,
tt_rank=config.tt_rank,
tt_min_rank=config.tt_min_rank,
use_quantum=config.use_quantum,
n_qubits=config.n_qubits,
n_quantum_layers=config.n_quantum_layers,
quantum_sparsity=config.quantum_sparsity,
rank_alpha=config.rank_alpha,
rank_smoothing=config.rank_smoothing,
dropout=config.dropout,
max_seq_len=config.max_seq_len,
)
for _ in range(config.n_layers)
])
# Output
self.ln_f = nn.LayerNorm(config.d_model)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
# Weight tying: embedding matrix = LM head
self.lm_head.weight = self.embedding.weight
self._post_init()
def _post_init(self):
"""Initialize weights."""
for name, param in self.named_parameters():
if "weight" in name and param.dim() >= 2:
nn.init.xavier_uniform_(param)
elif "bias" in name:
nn.init.zeros_(param)
def forward(self, input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
return_stats: bool = False):
"""
Args:
input_ids: (batch, seq_len) token indices
attention_mask: (batch, seq_len) optional padding mask
return_stats: return per-block statistics
Returns:
logits: (batch, seq_len, vocab_size)
stats: list of per-block stats dicts (if return_stats=True)
"""
x = self.embedding(input_ids)
x = self.pos_encoding(x)
all_stats = []
for block in self.blocks:
x, stats = block(x, mask=attention_mask)
all_stats.append(stats)
x = self.ln_f(x)
logits = self.lm_head(x)
if return_stats:
return logits, all_stats
return logits
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 20,
temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
"""Simple autoregressive generation."""
self.eval()
        for _ in range(max_new_tokens):
            # Condition on at most max_seq_len tokens, but keep the full
            # sequence (prompt plus generated tokens) in the returned tensor.
            context = input_ids[:, -self.config.max_seq_len:]
            logits = self(context)
logits = logits[:, -1, :] / temperature
if top_k > 0:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = float("-inf")
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, 1)
input_ids = torch.cat([input_ids, next_token], dim=-1)
return input_ids
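    # Typical generate() usage (illustrative):
    #   prompt = torch.tensor([[1, 5, 7]])          # (batch=1, seq_len=3)
    #   out = model.generate(prompt, max_new_tokens=10, temperature=0.8)
    #   out has shape (1, 3 + 10) and includes the original prompt tokens.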
def reset_schedulers(self):
"""Reset all rank schedulers and quantum routers."""
for block in self.blocks:
block.reset_scheduler()
@property
def stats(self) -> Dict:
"""Runtime statistics across all blocks."""
stats = {
"total_params": self.total_params,
"tt_params": self.tt_params,
"compression_ratio": self.compression_ratio,
"rank_history": {},
"quantum_usage": {},
}
for i, block in enumerate(self.blocks):
stats["rank_history"][i] = block.rank_scheduler.current_rank
if block.quantum_router is not None:
stats["quantum_usage"][i] = block.quantum_router.usage_percent
return stats
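    # The stats property returns a dict of the form (values illustrative):
    #   {"total_params": ..., "tt_params": ..., "compression_ratio": ...,
    #    "rank_history": {layer_idx: current_rank, ...},
    #    "quantum_usage": {layer_idx: usage_percent, ...}}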
@property
def total_params(self) -> int:
return sum(p.numel() for p in self.parameters())
@property
def trainable_params(self) -> int:
return sum(p.numel() for p in self.parameters() if p.requires_grad)
@property
def tt_params(self) -> int:
"""Count only TT-decomposed parameters."""
count = 0
for block in self.blocks:
for core in block.tt_ffn.up_proj.cores:
count += core.numel()
for core in block.tt_ffn.down_proj.cores:
count += core.numel()
return count
@property
def compression_ratio(self) -> float:
"""Estimated compression ratio vs. dense equivalent."""
dense_per_block = 2 * self.config.d_model * self.config.d_model * self.config.ff_multiplier
base = self.total_params - self.tt_params
tt = self.tt_params
return (base + dense_per_block * self.config.n_layers) / max(base + tt, 1)
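        # Worked example with hypothetical sizes: for d_model=256,
        # ff_multiplier=4 and n_layers=4, the dense FFNs alone would hold
        # 2 * 256 * 256 * 4 * 4 = 2,097,152 parameters; if the TT cores
        # replace them with ~300k parameters, the ratio above is roughly
        # (base + 2.1e6) / (base + 3e5), where `base` counts every non-TT
        # parameter (embeddings, attention, norms).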
    def flops_estimate(self, batch_size: int = 1, seq_len: Optional[int] = None) -> Dict:
"""Estimate total FLOPs."""
T = seq_len or self.config.max_seq_len
total = 0
breakdown = {}
for i, block in enumerate(self.blocks):
b = block.flops_estimate(batch_size, T)
total += b["total"]
breakdown[f"block_{i}"] = b
return {"total": total, "breakdown": breakdown}


class DenseBaseline(nn.Module):
"""
Standard transformer baseline — no TT, no quantum.
Same hyperparameters as QTensorFormer for fair comparison.
"""
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
self.embedding = nn.Embedding(config.vocab_size, config.d_model)
self.pos_encoding = PositionalEncoding(
config.d_model, config.max_seq_len, config.dropout
)
self.blocks = nn.ModuleList([
nn.ModuleDict({
"ln1": nn.LayerNorm(config.d_model),
"attn": nn.MultiheadAttention(
config.d_model, config.n_heads,
dropout=config.dropout, batch_first=True
),
"ln2": nn.LayerNorm(config.d_model),
"ffn": nn.Sequential(
nn.Linear(config.d_model, config.d_model * config.ff_multiplier),
nn.GELU(),
nn.Linear(config.d_model * config.ff_multiplier, config.d_model),
),
"dropout": nn.Dropout(config.dropout),
})
for _ in range(config.n_layers)
])
self.ln_f = nn.LayerNorm(config.d_model)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
self.lm_head.weight = self.embedding.weight
def forward(self, input_ids, attention_mask=None, return_stats=False):
x = self.embedding(input_ids)
x = self.pos_encoding(x)
        for block in self.blocks:
            h = block["ln1"](x)  # pre-norm once, reused as query/key/value
            # torch convention: True entries in key_padding_mask mark padding to ignore.
            attn_out, _ = block["attn"](
                h, h, h,
                key_padding_mask=attention_mask, need_weights=False
            )
x = x + block["dropout"](attn_out)
ffn_out = block["ffn"](block["ln2"](x))
x = x + block["dropout"](ffn_out)
x = self.ln_f(x)
logits = self.lm_head(x)
if return_stats:
return logits, []
return logits
@property
def total_params(self) -> int:
return sum(p.numel() for p in self.parameters())


def create_model(config: ModelConfig, model_type: str = "qtensor") -> nn.Module:
"""
Factory for model creation.
Args:
config: ModelConfig instance.
model_type: 'qtensor', 'tensor_only' (no quantum), 'dense' (baseline),
'distilled' (knowledge-distilled compact).
Returns:
        nn.Module instance.

    Note:
        The passed config is mutated in place: use_quantum is overwritten for the
        'qtensor', 'tensor_only', and 'distilled' variants, and 'distilled' also
        halves tt_rank.
"""
if model_type == "qtensor":
config.use_quantum = True
return QTensorFormer(config)
elif model_type == "tensor_only":
config.use_quantum = False
return QTensorFormer(config)
elif model_type == "dense":
return DenseBaseline(config)
elif model_type == "distilled":
config.use_quantum = True
config.tt_rank = max(2, config.tt_rank // 2) # More aggressively compressed
return QTensorFormer(config)
else:
raise ValueError(f"Unknown model_type: {model_type}")
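

if __name__ == "__main__":
    # Minimal smoke test (illustrative; assumes ModelConfig accepts these fields
    # as keyword arguments and provides defaults for the rest). Because of the
    # relative imports above, run it as `python -m src.models` from the repo root.
    demo_cfg = ModelConfig(vocab_size=1000, d_model=64, n_heads=4, n_layers=2)
    tokens = torch.randint(0, demo_cfg.vocab_size, (2, 16))
    for variant in ("qtensor", "tensor_only", "dense"):
        model = create_model(demo_cfg, model_type=variant)
        logits = model(tokens)
        print(f"{variant}: logits {tuple(logits.shape)}, "
              f"params {model.total_params:,}")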