""" Hybrid Transformer Block: Tensor + Quantum + Adaptive. v3 modular design — block can be configured as: - TT-FFN only (pure tensor) - Quantum only - Hybrid (both) - Standard MLP-FFN (baseline) Each block contains: - Multi-Head Attention (with entropy monitoring) - RankScheduler (entropy → TT rank) - QuantumRouter (selective quantum activation) - TTFeedForward (tensor-decomposed FFN) """ import torch import torch.nn as nn from .attention import MultiHeadAttention, HybridQAttention from .tensor_layers import TTFeedForward from .scheduler import RankScheduler, BudgetAwareScheduler from .router import QuantumRouter class HybridBlock(nn.Module): """ A single Q-TensorFormer block. Flow: x → LayerNorm → Attention + Entropy → RankScheduler: adjust TT ranks → LayerNorm → QuantumRouter (gate) → TTFeedForward (tensor-decomposed) → residual connection """ def __init__(self, d_model: int = 128, n_heads: int = 4, ff_multiplier: int = 4, tt_rank: int = 8, tt_min_rank: int = 2, use_quantum: bool = True, n_qubits: int = 4, n_quantum_layers: int = 2, quantum_sparsity: float = 0.7, rank_alpha: float = 2.0, rank_smoothing: float = 0.9, dropout: float = 0.1, max_seq_len: int = 128): super().__init__() self.d_model = d_model self.use_quantum = use_quantum self.is_hybrid = use_quantum # Flag for model-level detection # Attention self.attention = MultiHeadAttention( d_model, n_heads, dropout, max_seq_len, use_quantum_kernel=False ) # Layer norms self.ln1 = nn.LayerNorm(d_model) self.ln2 = nn.LayerNorm(d_model) # Rank scheduler self.rank_scheduler = RankScheduler( r_min=tt_min_rank, r_max=tt_rank, alpha=rank_alpha, smoothing=rank_smoothing ) # Quantum router if use_quantum: self.quantum_router = QuantumRouter( d_model=d_model, q_input_dim=n_qubits, target_sparsity=quantum_sparsity, ) else: self.quantum_router = None # Tensor-Train FFN self.tt_ffn = TTFeedForward( hidden_dim=d_model, ff_multiplier=ff_multiplier, rank=tt_rank, ) self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor, mask: torch.Tensor = None): """ Args: x: (batch, seq_len, d_model) mask: (batch, seq_len) optional padding mask Returns: output: (batch, seq_len, d_model) stats: dict with entropy, rank, quantum_usage """ stats = {} # Attention sublayer attn_out, entropy = self.attention( self.ln1(x), mask=mask, return_entropy=True ) x = x + self.dropout(attn_out) # Schedule rank from attention entropy mean_entropy = entropy.mean() if entropy.dim() > 0 else entropy new_rank = self.rank_scheduler(mean_entropy, seq_len=x.shape[1]) self.tt_ffn.set_rank(new_rank) stats["entropy"] = mean_entropy.item() stats["rank"] = new_rank # FFN sublayer normed = self.ln2(x) # Quantum routing quantum_out = torch.zeros_like(normed) if self.quantum_router is not None: quantum_out, q_mask = self.quantum_router(normed) stats["quantum_usage"] = self.quantum_router.usage_percent stats["quantum_sparsity"] = self.quantum_router.sparsity # TT feed-forward ffn_out = self.tt_ffn(normed) # Combine: quantum signal modifies the FFN input combined = normed + self.dropout(ffn_out + quantum_out) x = x + combined return x, stats def set_rank(self, rank: int): """Manually override rank.""" self.tt_ffn.set_rank(rank) def reset_scheduler(self): self.rank_scheduler.reset() if self.quantum_router is not None: self.quantum_router.reset_stats() @property def total_params(self) -> int: return sum(p.numel() for p in self.parameters()) def flops_estimate(self, batch_size: int = 1, seq_len: int = 128) -> dict: """Estimate FLOPs for this block.""" attn_flops = 
        ffn_flops = self.tt_ffn.flops(batch_size)
        return {
            "attention": attn_flops,
            "tt_ffn": ffn_flops,
            "total": attn_flops + ffn_flops,
        }
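

# Usage sketch: a minimal smoke test, not part of the module API. It assumes
# the sibling modules (.attention, .scheduler, .router, .tensor_layers) are
# importable and behave as documented above: MultiHeadAttention returns
# (output, entropy) when return_entropy=True, and QuantumRouter returns
# (signal, mask). Run from the package root, e.g. `python -m <package>.block`
# (the module path is an assumption).
if __name__ == "__main__":
    block = HybridBlock(d_model=128, n_heads=4, use_quantum=True)
    x = torch.randn(2, 16, 128)  # (batch, seq_len, d_model)
    out, stats = block(x)
    assert out.shape == x.shape  # residual structure preserves the shape
    print(f"entropy={stats['entropy']:.3f}  rank={stats['rank']}")
    print(f"quantum usage: {stats.get('quantum_usage', 'n/a')}")
    print(f"total params: {block.total_params:,}")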