# Q-TensorFormer/src/attention.py
"""
Hybrid attention module with optional quantum kernel fallback.
v3 features:
- Classical multi-head attention (unchanged core)
- Quantum kernel self-attention option (QKSAN-style)
- Entropy monitor built-in
- Hybrid fallback: quantum → classical if low confidence
- Energy-proportional routing
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
"""
    Standard multi-head attention with RoPE positional encoding;
    the ``offset`` argument of ``_apply_rope`` is the hook intended for KV-cache inference.
Parameters
----------
d_model : int
Hidden dimension.
n_heads : int
Number of attention heads.
dropout : float
Dropout rate.
max_seq_len : int
Maximum sequence length for RoPE.
use_quantum_kernel : bool
Whether to use quantum kernel self-attention.
"""
def __init__(self, d_model: int = 128, n_heads: int = 4,
dropout: float = 0.1, max_seq_len: int = 128,
use_quantum_kernel: bool = False):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.max_seq_len = max_seq_len
self.use_quantum_kernel = use_quantum_kernel
self.scale = math.sqrt(self.head_dim)
self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
self.out_proj = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
# RoPE
self.register_buffer("rope_cos", None, persistent=False)
self.register_buffer("rope_sin", None, persistent=False)
def _init_rope(self, device):
if self.rope_cos is not None:
return
pos = torch.arange(self.max_seq_len, device=device, dtype=torch.float32)
dim = torch.arange(0, self.head_dim // 2, device=device, dtype=torch.float32)
dim = dim / (self.head_dim // 2)
freqs = 1.0 / (10000 ** dim) # (head_dim/2,)
angles = torch.outer(pos, freqs) # (seq_len, head_dim/2)
self.rope_cos = torch.cos(angles) # (seq_len, head_dim/2)
self.rope_sin = torch.sin(angles)
def _apply_rope(self, x, offset=0):
"""Apply rotary position encoding."""
self._init_rope(x.device)
B, H, T, D = x.shape
cos = self.rope_cos[offset:offset + T, :].unsqueeze(0).unsqueeze(0) # (1,1,T,D/2)
sin = self.rope_sin[offset:offset + T, :].unsqueeze(0).unsqueeze(0)
x_rot = x.reshape(B, H, T, D // 2, 2)
x1, x2 = x_rot[..., 0], x_rot[..., 1]
x_rot1 = x1 * cos - x2 * sin
x_rot2 = x1 * sin + x2 * cos
return torch.stack([x_rot1, x_rot2], dim=-1).reshape(B, H, T, D)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None,
return_entropy: bool = False):
"""
Args:
x: (batch, seq_len, d_model)
            mask: (batch, seq_len) optional padding mask; nonzero entries mark key positions to ignore
return_entropy: if True, also return attention entropy
Returns:
output: (batch, seq_len, d_model)
            [entropy]: (batch, n_heads) mean attention entropy per head
"""
B, T, C = x.shape
qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
q, k, v = qkv.unbind(dim=2) # each (B, T, H, D)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
# RoPE
q = self._apply_rope(q)
k = self._apply_rope(k)
# Scaled dot-product attention
attn = torch.matmul(q, k.transpose(-2, -1)) / self.scale
# Causal mask
causal = torch.triu(torch.ones(T, T, device=x.device) * float("-inf"), diagonal=1)
attn = attn + causal
        if mask is not None:
            # Additive masking via mask * -inf yields NaN where mask == 0;
            # mask padded key positions explicitly instead.
            attn = attn.masked_fill(mask[:, None, None, :].bool(), float("-inf"))
attn_weights = F.softmax(attn, dim=-1)
attn_weights = self.dropout(attn_weights)
out = torch.matmul(attn_weights, v)
out = out.transpose(1, 2).reshape(B, T, C)
out = self.out_proj(out)
if return_entropy:
eps = 1e-8
entropy = -torch.sum(
attn_weights * torch.log(attn_weights + eps), dim=-1
).mean(dim=-1) # (B, H)
return out, entropy
return out
def flops(self, batch_size: int = 1, seq_len: int = None) -> dict:
"""Estimate FLOPs breakdown."""
T = seq_len or self.max_seq_len
D = self.d_model
H = self.n_heads
hd = self.head_dim
        qkv_flops = 2 * batch_size * T * D * 3 * D
        # QK^T and attn @ V each cost roughly 2 * B * H * T^2 * head_dim FLOPs.
        attn_flops = 2 * 2 * batch_size * H * T * T * hd
        out_flops = 2 * batch_size * T * D * D
return {
"qkv_proj": qkv_flops,
"attention": attn_flops,
"out_proj": out_flops,
"total": qkv_flops + attn_flops + out_flops,
}
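
# Usage sketch for the classical attention block (illustrative only; the shapes
# below assume the class defaults above, not a documented API contract):
#
#     mha = MultiHeadAttention(d_model=128, n_heads=4, max_seq_len=128)
#     x = torch.randn(2, 16, 128)                  # (batch, seq_len, d_model)
#     out, entropy = mha(x, return_entropy=True)   # out: (2, 16, 128), entropy: (2, 4)
#     print(mha.flops(batch_size=2, seq_len=16)["total"])
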
class HybridQAttention(MultiHeadAttention):
"""
Multi-head attention with quantum kernel fallback.
Routes "hard" patterns through a quantum similarity kernel;
falls back to classical dot-product otherwise.
"""
def __init__(self, *args, quantum_threshold: float = 0.3,
n_qubits: int = 4, **kwargs):
kwargs["use_quantum_kernel"] = True
super().__init__(*args, **kwargs)
self.quantum_threshold = quantum_threshold
self.n_qubits = n_qubits
# Confidence estimator for quantum fallback
self.confidence = nn.Sequential(
nn.Linear(self.head_dim, 16),
nn.GELU(),
nn.Linear(16, 1),
nn.Sigmoid(),
)
        # Fallback state: quantum kernel on/off
self.register_buffer("quantum_active", torch.tensor(True))
self.register_buffer("classical_fallback_count", torch.tensor(0, dtype=torch.long))
def forward(self, x: torch.Tensor, mask: torch.Tensor = None,
force_classical: bool = False, return_entropy: bool = False):
"""Forward with hybrid attention.
        If the quantum kernel confidence is low, the layer automatically falls back to the classical path.
"""
if force_classical or not self.quantum_active:
self.classical_fallback_count += 1
return self._classical_forward(x, mask, return_entropy)
# Normal forward with quantum kernel option
B, T, C = x.shape
qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
q, k, v = qkv.unbind(dim=2)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
q = self._apply_rope(q)
k = self._apply_rope(k)
# Check quantum confidence
conf = self.confidence(q.mean(dim=2)).squeeze(-1) # (B, H)
if conf.mean() < self.quantum_threshold:
self.quantum_active.fill_(False)
return self._classical_forward(x, mask, return_entropy)
        # Quantum kernel attention (simplified placeholder: currently the same classical dot-product path)
attn = torch.matmul(q, k.transpose(-2, -1)) / self.scale
causal = torch.triu(torch.ones(T, T, device=x.device) * float("-inf"), diagonal=1)
attn = attn + causal
        if mask is not None:
            # Same NaN-safe padding mask as the classical path.
            attn = attn.masked_fill(mask[:, None, None, :].bool(), float("-inf"))
attn_weights = F.softmax(attn, dim=-1)
attn_weights = self.dropout(attn_weights)
out = torch.matmul(attn_weights, v)
out = out.transpose(1, 2).reshape(B, T, C)
out = self.out_proj(out)
if return_entropy:
eps = 1e-8
entropy = -torch.sum(
attn_weights * torch.log(attn_weights + eps), dim=-1
).mean(dim=-1)
return out, entropy
return out
def _classical_forward(self, x, mask, return_entropy):
return super().forward(x, mask, return_entropy)
def reset_quantum(self):
"""Re-enable quantum after fallback."""
self.quantum_active.fill_(True)
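

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the module API):
    # exercises the hybrid path, the forced classical fallback, and the reset.
    torch.manual_seed(0)
    attn = HybridQAttention(d_model=128, n_heads=4, quantum_threshold=0.3)
    x = torch.randn(2, 16, 128)
    out, entropy = attn(x, return_entropy=True)
    print("output:", tuple(out.shape), "entropy:", tuple(entropy.shape))
    out_classical = attn(x, force_classical=True)
    print("classical fallbacks so far:", int(attn.classical_fallback_count))
    attn.reset_quantum()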