| """ |
Hybrid attention module with an optional quantum-kernel path and classical fallback.
| |
| v3 features: |
| - Classical multi-head attention (unchanged core) |
| - Quantum kernel self-attention option (QKSAN-style) |
| - Entropy monitor built-in |
| - Hybrid fallback: quantum → classical if low confidence |
| - Energy-proportional routing |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
|
|
|
|
| class MultiHeadAttention(nn.Module): |
| """ |
| Standard multi-head attention with RoPE positional encoding |
| and KV-cache support for inference. |
| |
| Parameters |
| ---------- |
| d_model : int |
| Hidden dimension. |
| n_heads : int |
| Number of attention heads. |
| dropout : float |
| Dropout rate. |
| max_seq_len : int |
| Maximum sequence length for RoPE. |
| use_quantum_kernel : bool |
| Whether to use quantum kernel self-attention. |
| """ |
|
|
| def __init__(self, d_model: int = 128, n_heads: int = 4, |
| dropout: float = 0.1, max_seq_len: int = 128, |
| use_quantum_kernel: bool = False): |
| super().__init__() |
| assert d_model % n_heads == 0 |
| self.d_model = d_model |
| self.n_heads = n_heads |
| self.head_dim = d_model // n_heads |
| self.max_seq_len = max_seq_len |
| self.use_quantum_kernel = use_quantum_kernel |
| self.scale = math.sqrt(self.head_dim) |
|
|
| self.qkv = nn.Linear(d_model, 3 * d_model, bias=False) |
| self.out_proj = nn.Linear(d_model, d_model, bias=False) |
| self.dropout = nn.Dropout(dropout) |
|
|
| |
| self.register_buffer("rope_cos", None, persistent=False) |
| self.register_buffer("rope_sin", None, persistent=False) |
|
|
| def _init_rope(self, device): |
| if self.rope_cos is not None: |
| return |
| pos = torch.arange(self.max_seq_len, device=device, dtype=torch.float32) |
| dim = torch.arange(0, self.head_dim // 2, device=device, dtype=torch.float32) |
| dim = dim / (self.head_dim // 2) |
| freqs = 1.0 / (10000 ** dim) |
| angles = torch.outer(pos, freqs) |
| self.rope_cos = torch.cos(angles) |
| self.rope_sin = torch.sin(angles) |
|
|
| def _apply_rope(self, x, offset=0): |
| """Apply rotary position encoding.""" |
| self._init_rope(x.device) |
| B, H, T, D = x.shape |
| cos = self.rope_cos[offset:offset + T, :].unsqueeze(0).unsqueeze(0) |
| sin = self.rope_sin[offset:offset + T, :].unsqueeze(0).unsqueeze(0) |
        # Rotate each adjacent feature pair (x_{2i}, x_{2i+1}) by the
        # position-dependent angle theta_i * pos:
        #   x'_{2i}   = x_{2i} * cos - x_{2i+1} * sin
        #   x'_{2i+1} = x_{2i} * sin + x_{2i+1} * cos
        x_rot = x.reshape(B, H, T, D // 2, 2)
        x1, x2 = x_rot[..., 0], x_rot[..., 1]
        x_rot1 = x1 * cos - x2 * sin
        x_rot2 = x1 * sin + x2 * cos
        return torch.stack([x_rot1, x_rot2], dim=-1).reshape(B, H, T, D)
|
|
| def forward(self, x: torch.Tensor, mask: torch.Tensor = None, |
| return_entropy: bool = False): |
| """ |
| Args: |
| x: (batch, seq_len, d_model) |
            mask: (batch, seq_len) optional padding mask; 1 marks tokens to
                attend to, 0 marks padded positions to be ignored
| return_entropy: if True, also return attention entropy |
| |
| Returns: |
| output: (batch, seq_len, d_model) |
| [entropy]: (batch, n_heads, seq_len) attention entropy |
| """ |
| B, T, C = x.shape |
| qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim) |
| q, k, v = qkv.unbind(dim=2) |
| q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) |
|
|
| |
| q = self._apply_rope(q) |
| k = self._apply_rope(k) |
|
|
| |
| attn = torch.matmul(q, k.transpose(-2, -1)) / self.scale |
|
|
| |
| causal = torch.triu(torch.ones(T, T, device=x.device) * float("-inf"), diagonal=1) |
| attn = attn + causal |
|
|
        if mask is not None:
            # 0 * -inf would produce NaNs; use masked_fill to drop padded keys.
            pad = mask.unsqueeze(1).unsqueeze(2)  # (B, 1, 1, T)
            attn = attn.masked_fill(pad == 0, float("-inf"))
|
|
| attn_weights = F.softmax(attn, dim=-1) |
| attn_weights = self.dropout(attn_weights) |
|
|
| out = torch.matmul(attn_weights, v) |
| out = out.transpose(1, 2).reshape(B, T, C) |
| out = self.out_proj(out) |
|
|
        if return_entropy:
            # Per-query attention entropy, shape (batch, n_heads, seq_len),
            # matching the docstring; log(T) corresponds to uniform attention.
            eps = 1e-8
            entropy = -torch.sum(
                attn_weights * torch.log(attn_weights + eps), dim=-1
            )
            return out, entropy
|
|
| return out |
|
|
| def flops(self, batch_size: int = 1, seq_len: int = None) -> dict: |
| """Estimate FLOPs breakdown.""" |
| T = seq_len or self.max_seq_len |
| D = self.d_model |
| H = self.n_heads |
| hd = self.head_dim |
|
|
| qkv_flops = 2 * batch_size * T * D * 3 * D |
        # Two matmuls per head: q @ k^T and attn_weights @ v.
        attn_flops = 2 * 2 * batch_size * H * T * T * hd
| out_flops = 2 * batch_size * T * D * D |
|
|
| return { |
| "qkv_proj": qkv_flops, |
| "attention": attn_flops, |
| "out_proj": out_flops, |
| "total": qkv_flops + attn_flops + out_flops, |
| } |
|
|
|
|
| class HybridQAttention(MultiHeadAttention): |
| """ |
| Multi-head attention with quantum kernel fallback. |
| |
| Routes "hard" patterns through a quantum similarity kernel; |
| falls back to classical dot-product otherwise. |
| """ |
|
|
| def __init__(self, *args, quantum_threshold: float = 0.3, |
| n_qubits: int = 4, **kwargs): |
| kwargs["use_quantum_kernel"] = True |
| super().__init__(*args, **kwargs) |
| self.quantum_threshold = quantum_threshold |
| self.n_qubits = n_qubits |
|
|
| |
        # Confidence gate: maps each head's mean query vector to a scalar in
        # (0, 1); the quantum path runs only if the batch mean is >= quantum_threshold.
        self.confidence = nn.Sequential(
| nn.Linear(self.head_dim, 16), |
| nn.GELU(), |
| nn.Linear(16, 1), |
| nn.Sigmoid(), |
| ) |
|
|
| |
| self.register_buffer("quantum_active", torch.tensor(True)) |
| self.register_buffer("classical_fallback_count", torch.tensor(0, dtype=torch.long)) |
|
|
| def forward(self, x: torch.Tensor, mask: torch.Tensor = None, |
| force_classical: bool = False, return_entropy: bool = False): |
| """Forward with hybrid attention. |
| |
| If quantum kernel confidence is low, auto-fallbacks to classical. |
| """ |
| if force_classical or not self.quantum_active: |
| self.classical_fallback_count += 1 |
| return self._classical_forward(x, mask, return_entropy) |
|
|
| |
| B, T, C = x.shape |
| qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim) |
| q, k, v = qkv.unbind(dim=2) |
| q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) |
|
|
| q = self._apply_rope(q) |
| k = self._apply_rope(k) |
|
|
| |
| conf = self.confidence(q.mean(dim=2)).squeeze(-1) |
| if conf.mean() < self.quantum_threshold: |
| self.quantum_active.fill_(False) |
| return self._classical_forward(x, mask, return_entropy) |
|
|
| |
        # Quantum-kernel scoring path. Note: this currently computes the same
        # scaled dot-product scores as the classical branch; a QKSAN-style
        # similarity kernel would replace this score computation.
        attn = torch.matmul(q, k.transpose(-2, -1)) / self.scale
| causal = torch.triu(torch.ones(T, T, device=x.device) * float("-inf"), diagonal=1) |
| attn = attn + causal |
|
|
        if mask is not None:
            # 0 * -inf would produce NaNs; use masked_fill to drop padded keys.
            pad = mask.unsqueeze(1).unsqueeze(2)  # (B, 1, 1, T)
            attn = attn.masked_fill(pad == 0, float("-inf"))
|
|
| attn_weights = F.softmax(attn, dim=-1) |
| attn_weights = self.dropout(attn_weights) |
|
|
| out = torch.matmul(attn_weights, v) |
| out = out.transpose(1, 2).reshape(B, T, C) |
| out = self.out_proj(out) |
|
|
        if return_entropy:
            # Per-query attention entropy, shape (batch, n_heads, seq_len),
            # consistent with MultiHeadAttention.forward.
            eps = 1e-8
            entropy = -torch.sum(
                attn_weights * torch.log(attn_weights + eps), dim=-1
            )
            return out, entropy
| return out |
|
|
| def _classical_forward(self, x, mask, return_entropy): |
| return super().forward(x, mask, return_entropy) |
|
|
| def reset_quantum(self): |
| """Re-enable quantum after fallback.""" |
| self.quantum_active.fill_(True) |
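

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the module API):
    # exercise both attention variants on random data with the default
    # hyperparameters and inspect shapes, entropy, FLOPs, and fallback state.
    torch.manual_seed(0)
    x = torch.randn(2, 16, 128)  # (batch, seq_len, d_model)

    mha = MultiHeadAttention(d_model=128, n_heads=4)
    out, entropy = mha(x, return_entropy=True)
    print("classical out:", tuple(out.shape))    # (2, 16, 128)
    print("entropy:", tuple(entropy.shape))      # (2, 4, 16)
    print("flops @ T=16:", mha.flops(batch_size=2, seq_len=16)["total"])

    hybrid = HybridQAttention(d_model=128, n_heads=4, quantum_threshold=0.3)
    out_q = hybrid(x)                            # routed by the confidence gate
    out_c = hybrid(x, force_classical=True)      # explicit classical fallback
    print("hybrid out:", tuple(out_q.shape), tuple(out_c.shape),
          "fallbacks:", int(hybrid.classical_fallback_count))
    hybrid.reset_quantum()                       # re-enable the quantum path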
|
|