""" Hybrid attention module with optional quantum kernel fallback. v3 features: - Classical multi-head attention (unchanged core) - Quantum kernel self-attention option (QKSAN-style) - Entropy monitor built-in - Hybrid fallback: quantum → classical if low confidence - Energy-proportional routing """ import torch import torch.nn as nn import torch.nn.functional as F import math class MultiHeadAttention(nn.Module): """ Standard multi-head attention with RoPE positional encoding and KV-cache support for inference. Parameters ---------- d_model : int Hidden dimension. n_heads : int Number of attention heads. dropout : float Dropout rate. max_seq_len : int Maximum sequence length for RoPE. use_quantum_kernel : bool Whether to use quantum kernel self-attention. """ def __init__(self, d_model: int = 128, n_heads: int = 4, dropout: float = 0.1, max_seq_len: int = 128, use_quantum_kernel: bool = False): super().__init__() assert d_model % n_heads == 0 self.d_model = d_model self.n_heads = n_heads self.head_dim = d_model // n_heads self.max_seq_len = max_seq_len self.use_quantum_kernel = use_quantum_kernel self.scale = math.sqrt(self.head_dim) self.qkv = nn.Linear(d_model, 3 * d_model, bias=False) self.out_proj = nn.Linear(d_model, d_model, bias=False) self.dropout = nn.Dropout(dropout) # RoPE self.register_buffer("rope_cos", None, persistent=False) self.register_buffer("rope_sin", None, persistent=False) def _init_rope(self, device): if self.rope_cos is not None: return pos = torch.arange(self.max_seq_len, device=device, dtype=torch.float32) dim = torch.arange(0, self.head_dim // 2, device=device, dtype=torch.float32) dim = dim / (self.head_dim // 2) freqs = 1.0 / (10000 ** dim) # (head_dim/2,) angles = torch.outer(pos, freqs) # (seq_len, head_dim/2) self.rope_cos = torch.cos(angles) # (seq_len, head_dim/2) self.rope_sin = torch.sin(angles) def _apply_rope(self, x, offset=0): """Apply rotary position encoding.""" self._init_rope(x.device) B, H, T, D = x.shape cos = self.rope_cos[offset:offset + T, :].unsqueeze(0).unsqueeze(0) # (1,1,T,D/2) sin = self.rope_sin[offset:offset + T, :].unsqueeze(0).unsqueeze(0) x_rot = x.reshape(B, H, T, D // 2, 2) x1, x2 = x_rot[..., 0], x_rot[..., 1] x_rot1 = x1 * cos - x2 * sin x_rot2 = x1 * sin + x2 * cos return torch.stack([x_rot1, x_rot2], dim=-1).reshape(B, H, T, D) def forward(self, x: torch.Tensor, mask: torch.Tensor = None, return_entropy: bool = False): """ Args: x: (batch, seq_len, d_model) mask: (batch, seq_len) optional attention mask return_entropy: if True, also return attention entropy Returns: output: (batch, seq_len, d_model) [entropy]: (batch, n_heads, seq_len) attention entropy """ B, T, C = x.shape qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim) q, k, v = qkv.unbind(dim=2) # each (B, T, H, D) q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # RoPE q = self._apply_rope(q) k = self._apply_rope(k) # Scaled dot-product attention attn = torch.matmul(q, k.transpose(-2, -1)) / self.scale # Causal mask causal = torch.triu(torch.ones(T, T, device=x.device) * float("-inf"), diagonal=1) attn = attn + causal if mask is not None: attn = attn + mask.unsqueeze(1).unsqueeze(2) * float("-inf") attn_weights = F.softmax(attn, dim=-1) attn_weights = self.dropout(attn_weights) out = torch.matmul(attn_weights, v) out = out.transpose(1, 2).reshape(B, T, C) out = self.out_proj(out) if return_entropy: eps = 1e-8 entropy = -torch.sum( attn_weights * torch.log(attn_weights + eps), dim=-1 ).mean(dim=-1) # (B, H) return out, entropy return out 
    def flops(self, batch_size: int = 1, seq_len: Optional[int] = None) -> dict:
        """Estimate FLOPs breakdown."""
        T = seq_len or self.max_seq_len
        D = self.d_model
        H = self.n_heads
        hd = self.head_dim
        qkv_flops = 2 * batch_size * T * D * 3 * D
        # Two (T x T) matmuls per head: QK^T scores and attn @ V.
        attn_flops = 4 * batch_size * H * T * T * hd
        out_flops = 2 * batch_size * T * D * D
        return {
            "qkv_proj": qkv_flops,
            "attention": attn_flops,
            "out_proj": out_flops,
            "total": qkv_flops + attn_flops + out_flops,
        }


class HybridQAttention(MultiHeadAttention):
    """
    Multi-head attention with quantum kernel fallback.

    Routes "hard" patterns through a quantum similarity kernel; falls back
    to classical dot-product attention otherwise.
    """

    def __init__(self, *args, quantum_threshold: float = 0.3, n_qubits: int = 4, **kwargs):
        kwargs["use_quantum_kernel"] = True
        super().__init__(*args, **kwargs)
        self.quantum_threshold = quantum_threshold
        self.n_qubits = n_qubits

        # Confidence estimator that gates the quantum branch.
        self.confidence = nn.Sequential(
            nn.Linear(self.head_dim, 16),
            nn.GELU(),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )

        # Fallback state: quantum branch on/off plus a fallback counter.
        self.register_buffer("quantum_active", torch.tensor(True))
        self.register_buffer("classical_fallback_count", torch.tensor(0, dtype=torch.long))

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None,
                force_classical: bool = False, return_entropy: bool = False):
        """Forward pass with hybrid attention.

        If quantum kernel confidence is low, automatically falls back to the
        classical path and disables the quantum branch until reset_quantum().
        """
        if force_classical or not self.quantum_active:
            self.classical_fallback_count += 1
            return self._classical_forward(x, mask, return_entropy)

        # Normal forward with quantum kernel option
        B, T, C = x.shape
        qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        q = self._apply_rope(q)
        k = self._apply_rope(k)

        # Check quantum confidence; disable the quantum branch if it is too low.
        conf = self.confidence(q.mean(dim=2)).squeeze(-1)  # (B, H)
        if conf.mean() < self.quantum_threshold:
            self.quantum_active.fill_(False)
            self.classical_fallback_count += 1
            return self._classical_forward(x, mask, return_entropy)

        # Quantum kernel attention (simplified placeholder: currently the same
        # scaled dot-product kernel as the classical path).
        attn = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        causal = torch.full((T, T), float("-inf"), device=x.device).triu(diagonal=1)
        attn = attn + causal
        if mask is not None:
            attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2) == 0, float("-inf"))
        attn_weights = F.softmax(attn, dim=-1)
        attn_weights = self.dropout(attn_weights)

        out = torch.matmul(attn_weights, v)
        out = out.transpose(1, 2).reshape(B, T, C)
        out = self.out_proj(out)

        if return_entropy:
            eps = 1e-8
            entropy = -torch.sum(
                attn_weights * torch.log(attn_weights + eps), dim=-1
            ).mean(dim=-1)  # (B, H)
            return out, entropy
        return out

    def _classical_forward(self, x, mask, return_entropy):
        return super().forward(x, mask, return_entropy)

    def reset_quantum(self):
        """Re-enable the quantum branch after a fallback."""
        self.quantum_active.fill_(True)
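

# Minimal usage sketch (an assumption, not part of the original module API):
# runs both attention variants on random data to show the entropy monitor,
# the FLOPs estimate, and the confidence-gated classical fallback. Shapes and
# hyperparameters simply mirror the class defaults above.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 16, 128)  # (batch, seq_len, d_model)

    # Classical attention with the entropy monitor enabled.
    mha = MultiHeadAttention(d_model=128, n_heads=4)
    out, entropy = mha(x, return_entropy=True)
    print("classical out:", tuple(out.shape), "entropy:", tuple(entropy.shape))
    print("FLOPs estimate:", mha.flops(batch_size=2, seq_len=16)["total"])

    # Hybrid attention: low confidence disables the quantum branch until reset.
    hybrid = HybridQAttention(d_model=128, n_heads=4, quantum_threshold=0.3)
    out = hybrid(x)
    print("quantum active:", bool(hybrid.quantum_active),
          "fallbacks:", int(hybrid.classical_fallback_count))
    hybrid.reset_quantum()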