| """ |
| Q-TensorFormer v3: Complete Model Architectures. |
| |
| Model variants: |
| - QTensorFormer: Full hybrid model (TT-FFN + quantum + adaptive rank) |
| - TensorBaseline: TT-FFN only (no quantum, fixed rank) |
| - DenseBaseline: Standard transformer (no TT, no quantum) |
| - DistilledVariants: Knowledge-distilled compact models |
| """ |

import math
from typing import Dict, List, Optional

import torch
import torch.nn as nn

from .blocks import HybridBlock
from .config import ModelConfig


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding."""

    def __init__(self, d_model: int, max_len: int = 128, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # Precompute the (max_len, d_model) table: sine on even feature
        # indices, cosine on odd ones, with geometrically spaced frequencies.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) *
            (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Buffer, not a parameter: moves with .to(device) but is not trained.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model); add the first seq_len positions.
        return self.dropout(x + self.pe[:, :x.size(1), :])


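# A minimal usage sketch (illustrative only, not part of the public API;
# dropout is disabled so the call is deterministic). It demonstrates that the
# encoding preserves shapes and makes distinct positions distinguishable.
def _demo_positional_encoding() -> None:
    pe = PositionalEncoding(d_model=16, max_len=32, dropout=0.0)
    x = torch.zeros(2, 8, 16)                   # (batch, seq_len, d_model)
    y = pe(x)
    assert y.shape == x.shape                   # shape is preserved
    assert not torch.equal(y[:, 0], y[:, 1])    # positions differ

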
class QTensorFormer(nn.Module):
    """
    Quantum-Enhanced Tensor Network Transformer.

    Full hybrid model: replaces the dense FFN with a tensor-train (TT)
    decomposition and adds quantum feature routing with adaptive rank
    scheduling.

    Parameters
    ----------
    config : ModelConfig
        Model configuration.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Token embedding plus fixed sinusoidal positions.
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_encoding = PositionalEncoding(
            config.d_model, config.max_seq_len, config.dropout
        )

        # Stack of hybrid blocks (attention + TT-FFN + optional quantum router).
        self.blocks = nn.ModuleList([
            HybridBlock(
                d_model=config.d_model,
                n_heads=config.n_heads,
                ff_multiplier=config.ff_multiplier,
                tt_rank=config.tt_rank,
                tt_min_rank=config.tt_min_rank,
                use_quantum=config.use_quantum,
                n_qubits=config.n_qubits,
                n_quantum_layers=config.n_quantum_layers,
                quantum_sparsity=config.quantum_sparsity,
                rank_alpha=config.rank_alpha,
                rank_smoothing=config.rank_smoothing,
                dropout=config.dropout,
                max_seq_len=config.max_seq_len,
            )
            for _ in range(config.n_layers)
        ])

        self.ln_f = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.embedding.weight

        self._post_init()

    def _post_init(self):
        """Xavier-initialize matrix weights and zero all biases.

        Runs after all submodules are constructed, so it overrides any
        initialization the blocks performed internally.
        """
        for name, param in self.named_parameters():
            if "weight" in name and param.dim() >= 2:
                nn.init.xavier_uniform_(param)
            elif "bias" in name:
                nn.init.zeros_(param)

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                return_stats: bool = False):
        """
        Args:
            input_ids: (batch, seq_len) token indices
            attention_mask: (batch, seq_len) optional padding mask
            return_stats: return per-block statistics

        Returns:
            logits: (batch, seq_len, vocab_size)
            stats: list of per-block stats dicts (if return_stats=True)
        """
        x = self.embedding(input_ids)
        x = self.pos_encoding(x)

        all_stats = []
        for block in self.blocks:
            x, stats = block(x, mask=attention_mask)
            all_stats.append(stats)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        if return_stats:
            return logits, all_stats
        return logits

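    # Usage sketch (shapes only; `config` is any ModelConfig instance):
    #
    #   model = QTensorFormer(config)
    #   ids = torch.randint(0, config.vocab_size, (4, config.max_seq_len))
    #   logits, stats = model(ids, return_stats=True)
    #   # logits: (4, config.max_seq_len, config.vocab_size)
    #   # stats:  one dict per block, so len(stats) == config.n_layers
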
    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 20,
                 temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
        """Simple autoregressive generation with temperature and top-k sampling."""
        self.eval()  # note: leaves the module in eval mode afterwards
        for _ in range(max_new_tokens):
            # Keep the context within the positional-encoding window.
            if input_ids.size(1) > self.config.max_seq_len:
                input_ids = input_ids[:, -self.config.max_seq_len:]
            logits = self(input_ids)
            logits = logits[:, -1, :] / temperature
            if top_k > 0:
                # Mask everything below the k-th largest logit.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
        return input_ids

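    # Sampling sketch (the prompt token ids below are placeholders):
    #
    #   prompt = torch.tensor([[1, 5, 42]])
    #   out = model.generate(prompt, max_new_tokens=10, temperature=0.8)
    #   # out.shape[1] == prompt.shape[1] + 10, unless the window truncates
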
    def reset_schedulers(self):
        """Reset all rank schedulers and quantum routers."""
        for block in self.blocks:
            block.reset_scheduler()

    @property
    def stats(self) -> Dict:
        """Runtime statistics across all blocks."""
        stats = {
            "total_params": self.total_params,
            "tt_params": self.tt_params,
            "compression_ratio": self.compression_ratio,
            "rank_history": {},
            "quantum_usage": {},
        }
        for i, block in enumerate(self.blocks):
            stats["rank_history"][i] = block.rank_scheduler.current_rank
            if block.quantum_router is not None:
                stats["quantum_usage"][i] = block.quantum_router.usage_percent
        return stats

    @property
    def total_params(self) -> int:
        return sum(p.numel() for p in self.parameters())

    @property
    def trainable_params(self) -> int:
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    @property
    def tt_params(self) -> int:
        """Count only TT-decomposed parameters (up/down projection cores)."""
        count = 0
        for block in self.blocks:
            for core in block.tt_ffn.up_proj.cores:
                count += core.numel()
            for core in block.tt_ffn.down_proj.cores:
                count += core.numel()
        return count

    @property
    def compression_ratio(self) -> float:
        """Parameter count of a dense-FFN equivalent divided by this model's."""
        dense_per_block = (
            2 * self.config.d_model * self.config.d_model * self.config.ff_multiplier
        )
        base = self.total_params - self.tt_params
        tt = self.tt_params
        return (base + dense_per_block * self.config.n_layers) / max(base + tt, 1)

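    # Worked example (hypothetical numbers): with d_model=256 and
    # ff_multiplier=4, each dense FFN holds 2 * 256 * 256 * 4 = 524,288
    # weights. If the TT cores realize the same maps in ~65k parameters, the
    # FFN itself compresses ~8x; the model-level ratio above is smaller
    # because the shared (non-TT) parameters appear in both counts.
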
    def flops_estimate(self, batch_size: int = 1,
                       seq_len: Optional[int] = None) -> Dict:
        """Estimate total forward FLOPs, summed over blocks."""
        T = seq_len or self.config.max_seq_len
        total = 0
        breakdown = {}
        for i, block in enumerate(self.blocks):
            b = block.flops_estimate(batch_size, T)
            total += b["total"]
            breakdown[f"block_{i}"] = b
        return {"total": total, "breakdown": breakdown}


class DenseBaseline(nn.Module):
    """
    Standard transformer baseline: no TT, no quantum.

    Uses the same hyperparameters as QTensorFormer for a fair comparison.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_encoding = PositionalEncoding(
            config.d_model, config.max_seq_len, config.dropout
        )

        self.blocks = nn.ModuleList([
            nn.ModuleDict({
                "ln1": nn.LayerNorm(config.d_model),
                "attn": nn.MultiheadAttention(
                    config.d_model, config.n_heads,
                    dropout=config.dropout, batch_first=True
                ),
                "ln2": nn.LayerNorm(config.d_model),
                "ffn": nn.Sequential(
                    nn.Linear(config.d_model, config.d_model * config.ff_multiplier),
                    nn.GELU(),
                    nn.Linear(config.d_model * config.ff_multiplier, config.d_model),
                ),
                "dropout": nn.Dropout(config.dropout),
            })
            for _ in range(config.n_layers)
        ])

        self.ln_f = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight

    def forward(self, input_ids, attention_mask=None, return_stats=False):
        x = self.embedding(input_ids)
        x = self.pos_encoding(x)

        for block in self.blocks:
            # Pre-norm attention; normalize once and reuse for q/k/v.
            h = block["ln1"](x)
            # nn.MultiheadAttention expects key_padding_mask to be True at
            # positions that should be ignored.
            attn_out, _ = block["attn"](
                h, h, h,
                key_padding_mask=attention_mask, need_weights=False
            )
            x = x + block["dropout"](attn_out)

            ffn_out = block["ffn"](block["ln2"](x))
            x = x + block["dropout"](ffn_out)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        if return_stats:
            return logits, []
        return logits

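    # Comparison sketch: with the same ModelConfig, the baseline's parameter
    # count should exceed the TT model's by roughly the FFN savings.
    #
    #   dense = DenseBaseline(cfg)
    #   tt = create_model(cfg, "tensor_only")
    #   dense.total_params / tt.total_params   # > 1 when TT compresses
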
    @property
    def total_params(self) -> int:
        return sum(p.numel() for p in self.parameters())


def create_model(config: ModelConfig, model_type: str = "qtensor") -> nn.Module:
    """
    Factory for model creation.

    Note: `use_quantum` (and, for 'distilled', `tt_rank`) are set on the
    passed-in config, so the config object is mutated in place.

    Args:
        config: ModelConfig instance.
        model_type: 'qtensor', 'tensor_only' (no quantum), 'dense' (baseline),
            or 'distilled' (compact, halved TT rank for knowledge distillation).

    Returns:
        nn.Module instance.
    """
    if model_type == "qtensor":
        config.use_quantum = True
        return QTensorFormer(config)
    elif model_type == "tensor_only":
        config.use_quantum = False
        return QTensorFormer(config)
    elif model_type == "dense":
        return DenseBaseline(config)
    elif model_type == "distilled":
        config.use_quantum = True
        config.tt_rank = max(2, config.tt_rank // 2)
        return QTensorFormer(config)
    else:
        raise ValueError(f"Unknown model_type: {model_type}")
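
# Minimal smoke test (sketch): assumes ModelConfig() provides usable defaults
# for vocab_size, d_model, etc.; adjust if the actual config requires
# explicit fields.
if __name__ == "__main__":
    cfg = ModelConfig()
    for kind in ("qtensor", "tensor_only", "dense"):
        model = create_model(cfg, kind)
        ids = torch.randint(0, cfg.vocab_size, (2, 16))
        out = model(ids)
        print(f"{kind}: logits {tuple(out.shape)}, "
              f"params {sum(p.numel() for p in model.parameters()):,}")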