| """ |
| Quantum Router: Selective Quantum Activation. |
| |
| Only "hard" tokens pass through the quantum circuit. |
| Decision mechanism: learned linear gate + straight-through estimator. |
| |
| v3 improvements: |
| - Sparsity target: ensures target fraction of tokens skip quantum |
| - Straight-through gradient for gradient-based learning |
| - Sparsity statistics tracking |
| - Fallback embedding for bypassed tokens |
| """ |
|
|
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
class QuantumRouter(nn.Module):
    """
    Selective quantum activation gate.

    Given a batch of token embeddings, computes a per-token
    probability of routing through quantum. Uses straight-through
    estimator: forward pass uses hard binary decisions, backward
    uses soft sigmoid gradient.

    Parameters
    ----------
    d_model : int
        Input feature dimension.
    q_input_dim : int
        Dimension expected by quantum circuit (typically n_qubits).
    target_sparsity : float
        Target fraction of tokens that SKIP quantum (0.7 = 70% skip).
    temperature : float
        Sigmoid temperature for gate decisions (lower = harder).
    """

    def __init__(self, d_model: int, q_input_dim: int = 4,
                 target_sparsity: float = 0.7, temperature: float = 1.0):
        super().__init__()
        self.d_model = d_model
        self.q_input_dim = q_input_dim
        self.target_sparsity = target_sparsity
        self.temperature = temperature

        # Small MLP that scores per-token "hardness" (one logit per token).
        self.gate_proj = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model // 4),
            nn.GELU(),
            nn.Linear(d_model // 4, 1),
        )

        # Down-projection into the quantum circuit's input space.
        self.q_proj = nn.Linear(d_model, q_input_dim)

        # FIX: build the output projection eagerly. It was previously
        # lazy-created inside forward() via hasattr, so its parameters were
        # missing from any optimizer constructed before the first forward
        # pass and it only followed the input device at creation time.
        # The leading-underscore name is kept for state_dict compatibility.
        self._q_out_proj = nn.Linear(q_input_dim, d_model)

        # Usage statistics as buffers: saved/restored with the module but
        # never trained.
        self.register_buffer("total_tokens", torch.tensor(0, dtype=torch.long))
        self.register_buffer("quantum_tokens", torch.tensor(0, dtype=torch.long))
        self.register_buffer("_ema_sparsity", torch.tensor(target_sparsity))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Route tokens selectively through quantum.

        Args:
            x: (*batch, seq_len, d_model)

        Returns:
            quantum_out: (*batch, seq_len, d_model) — quantum-processed
                tokens; positions that skip quantum are exactly zero.
            mask: (*batch, seq_len) — which tokens went through quantum (bool)
        """
        gate_logits = self.gate_proj(x).squeeze(-1)  # (*batch, seq_len)
        soft_mask = torch.sigmoid(gate_logits / self.temperature)

        # Straight-through estimator: forward value is the hard 0/1 decision,
        # backward gradient flows through the soft sigmoid.
        hard_mask = (soft_mask > 0.5).float()
        mask = hard_mask.detach() + soft_mask - soft_mask.detach()

        # NOTE(review): the quantum path is computed for *every* token and
        # the mask only zeroes the skipped ones — no compute is saved here.
        # F.gelu appears to be a stand-in for the actual quantum circuit.
        q_input = self.q_proj(x)
        quantum_out = self._q_out_proj(F.gelu(q_input))

        # Zero out tokens that skip quantum (mask forward value is 0 or 1).
        output = mask.unsqueeze(-1) * quantum_out

        # Statistics: counts and an EMA of the skip rate. mask.numel() equals
        # seq_len times the product of the batch dims (1 when unbatched),
        # matching the old math_prod-based computation.
        with torch.no_grad():
            n_tokens = mask.numel()
            n_quantum = int(hard_mask.sum().item())
            self.total_tokens += n_tokens
            self.quantum_tokens += n_quantum
            actual_rate = n_quantum / max(n_tokens, 1)
            # EMA: sparsity is the fraction that SKIPPED quantum.
            self._ema_sparsity.mul_(0.99).add_(1.0 - actual_rate, alpha=0.01)

        return output, mask.detach().bool()

    @property
    def sparsity(self) -> float:
        """Fraction of tokens that SKIP the quantum circuit (EMA estimate)."""
        return self._ema_sparsity.item()

    @property
    def usage_percent(self) -> float:
        """Fraction of tokens that use the quantum circuit (1 - sparsity)."""
        return 1.0 - self.sparsity

    def reset_stats(self) -> None:
        """Reset counters and re-seed the sparsity EMA at its target."""
        self.total_tokens.zero_()
        self.quantum_tokens.zero_()
        self._ema_sparsity.fill_(self.target_sparsity)

    def reset_state(self) -> None:
        """Full reset for clean evaluation runs: stats and all weights."""
        self.reset_stats()
        for m in self.modules():
            # Re-initializes Linear/LayerNorm weights; modules without
            # reset_parameters (e.g. Sequential, GELU) are skipped.
            if hasattr(m, "reset_parameters"):
                m.reset_parameters()

    def extra_repr(self) -> str:
        return (f"d_model={self.d_model}, q_dim={self.q_input_dim}, "
                f"target_sparsity={self.target_sparsity:.1%}")
|
|
|
|
def math_prod(iterable):
    """Return the product of *iterable* (1 for an empty iterable).

    Delegates to the C-accelerated ``math.prod`` (Python 3.8+) instead of
    the previous hand-rolled accumulation loop; behavior is identical.
    """
    return math.prod(iterable)
|
|