""" Quantum Router: Selective Quantum Activation. Only "hard" tokens pass through the quantum circuit. Decision mechanism: learned linear gate + straight-through estimator. v3 improvements: - Sparsity target: ensures target fraction of tokens skip quantum - Straight-through gradient for gradient-based learning - Sparsity statistics tracking - Fallback embedding for bypassed tokens """ import torch import torch.nn as nn import torch.nn.functional as F class QuantumRouter(nn.Module): """ Selective quantum activation gate. Given a batch of token embeddings, computes a per-token probability of routing through quantum. Uses straight-through estimator: forward pass uses hard binary decisions, backward uses soft sigmoid gradient. Parameters ---------- d_model : int Input feature dimension. q_input_dim : int Dimension expected by quantum circuit (typically n_qubits). target_sparsity : float Target fraction of tokens that SKIP quantum (0.7 = 70% skip). temperature : float Softmax temperature for gate decisions (lower = harder). """ def __init__(self, d_model: int, q_input_dim: int = 4, target_sparsity: float = 0.7, temperature: float = 1.0): super().__init__() self.d_model = d_model self.q_input_dim = q_input_dim self.target_sparsity = target_sparsity self.temperature = temperature # Projection for gate decision self.gate_proj = nn.Sequential( nn.LayerNorm(d_model), nn.Linear(d_model, d_model // 4), nn.GELU(), nn.Linear(d_model // 4, 1), ) # Projection to quantum input dimension self.q_proj = nn.Linear(d_model, q_input_dim) # Statistics self.register_buffer("total_tokens", torch.tensor(0, dtype=torch.long)) self.register_buffer("quantum_tokens", torch.tensor(0, dtype=torch.long)) self.register_buffer("_ema_sparsity", torch.tensor(target_sparsity)) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Route tokens selectively through quantum. Args: x: (*batch, seq_len, d_model) Returns: quantum_out: (*batch, seq_len, d_model) — quantum-processed tokens mask: (*batch, seq_len) — which tokens went through quantum (bool) """ *batch_dims, seq_len, d_model = x.shape # Gate decision gate_logits = self.gate_proj(x).squeeze(-1) # (*, seq_len) soft_mask = torch.sigmoid(gate_logits / self.temperature) # Straight-through: hard forward, soft backward hard_mask = (soft_mask > 0.5).float() mask = hard_mask.detach() + soft_mask - soft_mask.detach() # Project selected tokens to quantum dimension q_input = self.q_proj(x) # (*, seq_len, q_input_dim) # TODO: actual quantum circuit call goes here # For now: project back to d_model with learned linear layer quantum_out = F.gelu(q_input) if not hasattr(self, '_q_out_proj'): self._q_out_proj = nn.Linear(self.q_input_dim, d_model).to(x.device) quantum_out = self._q_out_proj(quantum_out) # Gate output mask_expanded = mask.unsqueeze(-1) # (*, seq_len, 1) output = mask_expanded * quantum_out # Update statistics with torch.no_grad(): n_tokens = seq_len * max(1, math_prod(batch_dims)) n_quantum = int(mask_expanded.sum().item()) self.total_tokens += n_tokens self.quantum_tokens += n_quantum actual_rate = n_quantum / max(n_tokens, 1) self._ema_sparsity.mul_(0.99).add_( (1 - actual_rate), alpha=0.01 ) return output, mask.detach().bool() @property def sparsity(self) -> float: """Fraction of tokens that SKIP the quantum circuit.""" return self._ema_sparsity.item() @property def usage_percent(self) -> float: """Fraction of tokens that use the quantum circuit.""" return 1.0 - self.sparsity def reset_stats(self): self.total_tokens.zero_() self.quantum_tokens.zero_() self._ema_sparsity.fill_(self.target_sparsity) def reset_state(self): """Full reset for clean evaluation runs.""" self.reset_stats() for m in self.modules(): if hasattr(m, "reset_parameters"): m.reset_parameters() def extra_repr(self) -> str: return (f"d_model={self.d_model}, q_dim={self.q_input_dim}, " f"target_sparsity={self.target_sparsity:.1%}") def math_prod(iterable): """Safe product of iterable.""" result = 1 for x in iterable: result *= x return result