# Q-TensorFormer — src/router.py
# (source drop v3.0.0, uploaded by Premchan369, commit b9c4adf)
"""
Quantum Router: Selective Quantum Activation.
Only "hard" tokens pass through the quantum circuit.
Decision mechanism: learned linear gate + straight-through estimator.
v3 improvements:
- Sparsity target: ensures target fraction of tokens skip quantum
- Straight-through gradient for gradient-based learning
- Sparsity statistics tracking
- Fallback embedding for bypassed tokens (planned — current version zeroes them)
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
class QuantumRouter(nn.Module):
    """
    Selective quantum activation gate.

    Given a batch of token embeddings, computes a per-token probability
    of routing through quantum. Uses a straight-through estimator: the
    forward pass uses hard binary decisions, the backward pass uses the
    soft sigmoid gradient.

    Parameters
    ----------
    d_model : int
        Input feature dimension.
    q_input_dim : int
        Dimension expected by the quantum circuit (typically n_qubits).
    target_sparsity : float
        Target fraction of tokens that SKIP quantum (0.7 = 70% skip).
        NOTE(review): nothing in this module enforces the target — it only
        seeds the EMA statistic. Presumably an auxiliary loss elsewhere
        pulls the gate toward it; confirm with callers.
    temperature : float
        Sigmoid temperature for gate decisions (lower = harder).
    """

    def __init__(self, d_model: int, q_input_dim: int = 4,
                 target_sparsity: float = 0.7, temperature: float = 1.0):
        super().__init__()
        self.d_model = d_model
        self.q_input_dim = q_input_dim
        self.target_sparsity = target_sparsity
        self.temperature = temperature
        # Small MLP producing one gate logit per token.
        self.gate_proj = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model // 4),
            nn.GELU(),
            nn.Linear(d_model // 4, 1),
        )
        # Projection down to the quantum input dimension ...
        self.q_proj = nn.Linear(d_model, q_input_dim)
        # ... and back up to d_model.
        # FIX: this layer used to be created lazily inside forward(), which
        # left it out of state_dict(), out of any optimizer built before the
        # first forward pass, and unmoved by .to()/.cuda() (the ad-hoc
        # .to(x.device) handled device but not dtype). Registering it here
        # fixes all three.
        self._q_out_proj = nn.Linear(q_input_dim, d_model)
        # Running usage statistics (buffers so they follow device moves and
        # are persisted in state_dict).
        self.register_buffer("total_tokens", torch.tensor(0, dtype=torch.long))
        self.register_buffer("quantum_tokens", torch.tensor(0, dtype=torch.long))
        self.register_buffer("_ema_sparsity", torch.tensor(target_sparsity))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Route tokens selectively through the (placeholder) quantum path.

        Args:
            x: (*batch, seq_len, d_model)
        Returns:
            output: (*batch, seq_len, d_model) — gated quantum-processed
                tokens; bypassed tokens are zeroed (no fallback embedding yet).
            mask: (*batch, seq_len) bool — True where the token went quantum.
        """
        # Gate decision: one logit per token.
        gate_logits = self.gate_proj(x).squeeze(-1)      # (*, seq_len)
        soft_mask = torch.sigmoid(gate_logits / self.temperature)
        # Straight-through estimator: hard 0/1 values in the forward pass,
        # sigmoid gradient in the backward pass (soft - soft.detach() is
        # exactly zero in value, so the forward value equals hard_mask).
        hard_mask = (soft_mask > 0.5).float()
        mask = hard_mask.detach() + soft_mask - soft_mask.detach()

        # Placeholder quantum path: project to the circuit's input dimension,
        # apply a nonlinearity, project back to d_model.
        # TODO: actual quantum circuit call goes here.
        q_input = self.q_proj(x)                         # (*, seq_len, q_input_dim)
        quantum_out = self._q_out_proj(F.gelu(q_input))  # (*, seq_len, d_model)

        # Zero out tokens the gate decided to skip.
        mask_expanded = mask.unsqueeze(-1)               # (*, seq_len, 1)
        output = mask_expanded * quantum_out

        # Update running statistics (counts only, no gradient).
        with torch.no_grad():
            n_tokens = mask.numel()  # == seq_len * prod(batch dims)
            n_quantum = int(hard_mask.sum().item())
            self.total_tokens += n_tokens
            self.quantum_tokens += n_quantum
            actual_rate = n_quantum / max(n_tokens, 1)
            # EMA of the observed skip rate (1 - usage), decay 0.99.
            self._ema_sparsity.mul_(0.99).add_(1 - actual_rate, alpha=0.01)
        return output, mask.detach().bool()

    @property
    def sparsity(self) -> float:
        """EMA fraction of tokens that SKIP the quantum circuit."""
        return self._ema_sparsity.item()

    @property
    def usage_percent(self) -> float:
        """EMA fraction of tokens that use the quantum circuit."""
        return 1.0 - self.sparsity

    def reset_stats(self) -> None:
        """Zero the token counters and reset the EMA to the target."""
        self.total_tokens.zero_()
        self.quantum_tokens.zero_()
        self._ema_sparsity.fill_(self.target_sparsity)

    def reset_state(self) -> None:
        """Full reset for clean evaluation runs: stats and learned weights."""
        self.reset_stats()
        for m in self.modules():
            if hasattr(m, "reset_parameters"):
                m.reset_parameters()

    def extra_repr(self) -> str:
        return (f"d_model={self.d_model}, q_dim={self.q_input_dim}, "
                f"target_sparsity={self.target_sparsity:.1%}")
def math_prod(iterable):
    """Return the product of the items of *iterable* (1 if empty).

    Thin wrapper over the stdlib :func:`math.prod` (the hand-rolled loop
    it replaces computed the same thing); kept for backward compatibility
    with existing call sites.
    """
    return math.prod(iterable)