"""
Quantum Router: Selective Quantum Activation.
Only "hard" tokens pass through the quantum circuit.
Decision mechanism: learned gate MLP + straight-through estimator.
v3 improvements:
- Sparsity target: ensures target fraction of tokens skip quantum
- Straight-through gradient for gradient-based learning
- Sparsity statistics tracking
- Fallback embedding for bypassed tokens
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class QuantumRouter(nn.Module):
"""
Selective quantum activation gate.
Given a batch of token embeddings, computes a per-token
probability of routing through quantum. Uses straight-through
estimator: forward pass uses hard binary decisions, backward
uses soft sigmoid gradient.
Parameters
----------
d_model : int
Input feature dimension.
q_input_dim : int
Dimension expected by quantum circuit (typically n_qubits).
target_sparsity : float
Target fraction of tokens that SKIP quantum (0.7 = 70% skip).
temperature : float
        Sigmoid temperature for gate decisions (lower = harder).
"""
def __init__(self, d_model: int, q_input_dim: int = 4,
target_sparsity: float = 0.7, temperature: float = 1.0):
super().__init__()
self.d_model = d_model
self.q_input_dim = q_input_dim
self.target_sparsity = target_sparsity
self.temperature = temperature
# Projection for gate decision
self.gate_proj = nn.Sequential(
nn.LayerNorm(d_model),
nn.Linear(d_model, d_model // 4),
nn.GELU(),
nn.Linear(d_model // 4, 1),
)
        # Projection to quantum input dimension
        self.q_proj = nn.Linear(d_model, q_input_dim)
        # Projection from quantum output back to model dimension. Built in
        # __init__ (not lazily in forward) so its parameters are registered
        # with the module, seen by optimizers, and moved by .to()/.cuda().
        self.q_out_proj = nn.Linear(q_input_dim, d_model)
# Statistics
self.register_buffer("total_tokens", torch.tensor(0, dtype=torch.long))
self.register_buffer("quantum_tokens", torch.tensor(0, dtype=torch.long))
self.register_buffer("_ema_sparsity", torch.tensor(target_sparsity))
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Route tokens selectively through quantum.
Args:
x: (*batch, seq_len, d_model)
        Returns:
            output: (*batch, seq_len, d_model) — quantum-processed where routed,
                original embedding (fallback) where bypassed
            mask: (*batch, seq_len) — which tokens went through quantum (bool)
"""
*batch_dims, seq_len, d_model = x.shape
# Gate decision
gate_logits = self.gate_proj(x).squeeze(-1) # (*, seq_len)
soft_mask = torch.sigmoid(gate_logits / self.temperature)
# Straight-through: hard forward, soft backward
hard_mask = (soft_mask > 0.5).float()
mask = hard_mask.detach() + soft_mask - soft_mask.detach()
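        # In value, `mask` equals hard_mask; in autograd, its gradient follows soft_mask.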
# Project selected tokens to quantum dimension
q_input = self.q_proj(x) # (*, seq_len, q_input_dim)
        # TODO: actual quantum circuit call goes here.
        # Placeholder path: nonlinearity, then learned projection back to d_model.
        quantum_out = self.q_out_proj(F.gelu(q_input))
        # Combine: quantum output where gated on, original embedding otherwise
        # (the fallback for bypassed tokens promised in the module docstring)
        mask_expanded = mask.unsqueeze(-1)  # (*, seq_len, 1)
        output = mask_expanded * quantum_out + (1.0 - mask_expanded) * x
# Update statistics
with torch.no_grad():
            n_tokens = seq_len * math.prod(batch_dims)
n_quantum = int(mask_expanded.sum().item())
self.total_tokens += n_tokens
self.quantum_tokens += n_quantum
actual_rate = n_quantum / max(n_tokens, 1)
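            # EMA of the skip rate (decay 0.99 ≈ smoothing over ~100 batches)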
self._ema_sparsity.mul_(0.99).add_(
(1 - actual_rate), alpha=0.01
)
return output, mask.detach().bool()
@property
def sparsity(self) -> float:
"""Fraction of tokens that SKIP the quantum circuit."""
return self._ema_sparsity.item()
@property
def usage_percent(self) -> float:
"""Fraction of tokens that use the quantum circuit."""
return 1.0 - self.sparsity
def reset_stats(self):
self.total_tokens.zero_()
self.quantum_tokens.zero_()
self._ema_sparsity.fill_(self.target_sparsity)
def reset_state(self):
"""Full reset for clean evaluation runs."""
self.reset_stats()
for m in self.modules():
if hasattr(m, "reset_parameters"):
m.reset_parameters()
def extra_repr(self) -> str:
return (f"d_model={self.d_model}, q_dim={self.q_input_dim}, "
f"target_sparsity={self.target_sparsity:.1%}")
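

# --- Usage sketch -------------------------------------------------------------
# A minimal smoke test, assuming random embeddings stand in for real tokens;
# shapes and hyperparameters here are illustrative, not canonical.
if __name__ == "__main__":
    torch.manual_seed(0)
    router = QuantumRouter(d_model=64, q_input_dim=4, target_sparsity=0.7)
    x = torch.randn(2, 16, 64)  # (batch, seq_len, d_model)
    out, mask = router(x)
    # Straight-through lets gradients reach gate_proj through the hard mask.
    loss = out.pow(2).mean() + sparsity_loss(router, x)
    loss.backward()
    print(out.shape, mask.shape)
    print(f"quantum usage: {router.usage_percent:.1%}")
    print(f"EMA sparsity:  {router.sparsity:.1%}")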