""" Q-TensorFormer v3: Complete Model Architectures. Model variants: - QTensorFormer: Full hybrid model (TT-FFN + quantum + adaptive rank) - TensorBaseline: TT-FFN only (no quantum, fixed rank) - DenseBaseline: Standard transformer (no TT, no quantum) - DistilledVariants: Knowledge-distilled compact models """ import torch import torch.nn as nn import math from typing import Optional, Dict, List from .blocks import HybridBlock from .config import ModelConfig class PositionalEncoding(nn.Module): """Fixed sinusoidal positional encoding.""" def __init__(self, d_model: int, max_len: int = 128, dropout: float = 0.1): super().__init__() self.dropout = nn.Dropout(dropout) pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) div_term = torch.exp( torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model) ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) self.register_buffer("pe", pe.unsqueeze(0)) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.dropout(x + self.pe[:, :x.size(1), :]) class QTensorFormer(nn.Module): """ Quantum-Enhanced Tensor Network Transformer. Full hybrid model: replaces FFN with TT decomposition and adds quantum feature routing with adaptive rank scheduling. Parameters ---------- config : ModelConfig Model configuration. """ def __init__(self, config: ModelConfig): super().__init__() self.config = config # Embeddings self.embedding = nn.Embedding(config.vocab_size, config.d_model) self.pos_encoding = PositionalEncoding( config.d_model, config.max_seq_len, config.dropout ) # Transformer blocks self.blocks = nn.ModuleList([ HybridBlock( d_model=config.d_model, n_heads=config.n_heads, ff_multiplier=config.ff_multiplier, tt_rank=config.tt_rank, tt_min_rank=config.tt_min_rank, use_quantum=config.use_quantum, n_qubits=config.n_qubits, n_quantum_layers=config.n_quantum_layers, quantum_sparsity=config.quantum_sparsity, rank_alpha=config.rank_alpha, rank_smoothing=config.rank_smoothing, dropout=config.dropout, max_seq_len=config.max_seq_len, ) for _ in range(config.n_layers) ]) # Output self.ln_f = nn.LayerNorm(config.d_model) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) # Weight tying: embedding matrix = LM head self.lm_head.weight = self.embedding.weight self._post_init() def _post_init(self): """Initialize weights.""" for name, param in self.named_parameters(): if "weight" in name and param.dim() >= 2: nn.init.xavier_uniform_(param) elif "bias" in name: nn.init.zeros_(param) def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, return_stats: bool = False): """ Args: input_ids: (batch, seq_len) token indices attention_mask: (batch, seq_len) optional padding mask return_stats: return per-block statistics Returns: logits: (batch, seq_len, vocab_size) stats: list of per-block stats dicts (if return_stats=True) """ x = self.embedding(input_ids) x = self.pos_encoding(x) all_stats = [] for block in self.blocks: x, stats = block(x, mask=attention_mask) all_stats.append(stats) x = self.ln_f(x) logits = self.lm_head(x) if return_stats: return logits, all_stats return logits @torch.no_grad() def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 20, temperature: float = 1.0, top_k: int = 50) -> torch.Tensor: """Simple autoregressive generation.""" self.eval() for _ in range(max_new_tokens): if input_ids.size(1) > self.config.max_seq_len: input_ids = input_ids[:, -self.config.max_seq_len:] 
            logits = self(input_ids)
            logits = logits[:, -1, :] / temperature
            if top_k > 0:
                # Top-k filtering: mask everything below the k-th largest logit
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
        return input_ids

    def reset_schedulers(self):
        """Reset all rank schedulers and quantum routers."""
        for block in self.blocks:
            block.reset_scheduler()

    @property
    def stats(self) -> Dict:
        """Runtime statistics across all blocks."""
        stats = {
            "total_params": self.total_params,
            "tt_params": self.tt_params,
            "compression_ratio": self.compression_ratio,
            "rank_history": {},
            "quantum_usage": {},
        }
        for i, block in enumerate(self.blocks):
            stats["rank_history"][i] = block.rank_scheduler.current_rank
            if block.quantum_router is not None:
                stats["quantum_usage"][i] = block.quantum_router.usage_percent
        return stats

    @property
    def total_params(self) -> int:
        return sum(p.numel() for p in self.parameters())

    @property
    def trainable_params(self) -> int:
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    @property
    def tt_params(self) -> int:
        """Count only TT-decomposed parameters."""
        count = 0
        for block in self.blocks:
            for core in block.tt_ffn.up_proj.cores:
                count += core.numel()
            for core in block.tt_ffn.down_proj.cores:
                count += core.numel()
        return count

    @property
    def compression_ratio(self) -> float:
        """Estimated compression ratio vs. dense equivalent."""
        dense_per_block = (
            2 * self.config.d_model * self.config.d_model * self.config.ff_multiplier
        )
        base = self.total_params - self.tt_params
        tt = self.tt_params
        return (base + dense_per_block * self.config.n_layers) / max(base + tt, 1)
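    # Worked example for `compression_ratio` (illustrative numbers only, not
    # ModelConfig defaults): with d_model=256, ff_multiplier=4 and n_layers=4,
    # the dense FFN weights that the TT cores stand in for would amount to
    #     2 * 256 * 256 * 4 = 524,288 parameters per block (biases ignored),
    # i.e. 2,097,152 across the model, so the ratio compares
    # (non-TT params + 2,097,152) against (non-TT params + actual TT core params).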
""" def __init__(self, config: ModelConfig): super().__init__() self.config = config self.embedding = nn.Embedding(config.vocab_size, config.d_model) self.pos_encoding = PositionalEncoding( config.d_model, config.max_seq_len, config.dropout ) self.blocks = nn.ModuleList([ nn.ModuleDict({ "ln1": nn.LayerNorm(config.d_model), "attn": nn.MultiheadAttention( config.d_model, config.n_heads, dropout=config.dropout, batch_first=True ), "ln2": nn.LayerNorm(config.d_model), "ffn": nn.Sequential( nn.Linear(config.d_model, config.d_model * config.ff_multiplier), nn.GELU(), nn.Linear(config.d_model * config.ff_multiplier, config.d_model), ), "dropout": nn.Dropout(config.dropout), }) for _ in range(config.n_layers) ]) self.ln_f = nn.LayerNorm(config.d_model) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) self.lm_head.weight = self.embedding.weight def forward(self, input_ids, attention_mask=None, return_stats=False): x = self.embedding(input_ids) x = self.pos_encoding(x) for block in self.blocks: attn_out, _ = block["attn"]( block["ln1"](x), block["ln1"](x), block["ln1"](x), key_padding_mask=attention_mask, need_weights=False ) x = x + block["dropout"](attn_out) ffn_out = block["ffn"](block["ln2"](x)) x = x + block["dropout"](ffn_out) x = self.ln_f(x) logits = self.lm_head(x) if return_stats: return logits, [] return logits @property def total_params(self) -> int: return sum(p.numel() for p in self.parameters()) def create_model(config: ModelConfig, model_type: str = "qtensor") -> nn.Module: """ Factory for model creation. Args: config: ModelConfig instance. model_type: 'qtensor', 'tensor_only' (no quantum), 'dense' (baseline), 'distilled' (knowledge-distilled compact). Returns: nn.Module instance. """ if model_type == "qtensor": config.use_quantum = True return QTensorFormer(config) elif model_type == "tensor_only": config.use_quantum = False return QTensorFormer(config) elif model_type == "dense": return DenseBaseline(config) elif model_type == "distilled": config.use_quantum = True config.tt_rank = max(2, config.tt_rank // 2) # More aggressively compressed return QTensorFormer(config) else: raise ValueError(f"Unknown model_type: {model_type}")