"""
Quantum Router: Selective Quantum Activation.

Only "hard" tokens pass through the quantum circuit.
Decision mechanism: a small learned gate network (MLP) + straight-through estimator.

v3 improvements:
  - Sparsity target: a configurable target fraction of tokens that skip
    quantum (an auxiliary-loss sketch is included below)
  - Straight-through gradient for gradient-based learning
  - Sparsity statistics tracking
  - Fallback embedding for bypassed tokens
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class QuantumRouter(nn.Module):
    """
    Selective quantum activation gate.

    Given a batch of token embeddings, computes a per-token
    probability of routing through the quantum circuit. Uses a
    straight-through estimator: the forward pass makes hard binary
    decisions, while the backward pass uses the soft sigmoid gradient.

    Parameters
    ----------
    d_model : int
        Input feature dimension.
    q_input_dim : int
        Dimension expected by the quantum circuit (typically n_qubits).
    target_sparsity : float
        Target fraction of tokens that SKIP quantum (0.7 = 70% skip).
    temperature : float
        Sigmoid temperature for gate decisions (lower = harder).
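
    Examples
    --------
    A minimal shape-level smoke test (the quantum call itself is still a
    placeholder; see the TODO in ``forward``):

    >>> router = QuantumRouter(d_model=64)
    >>> x = torch.randn(2, 10, 64)
    >>> out, mask = router(x)
    >>> out.shape, mask.shape
    (torch.Size([2, 10, 64]), torch.Size([2, 10]))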
    """

    def __init__(self, d_model: int, q_input_dim: int = 4,
                 target_sparsity: float = 0.7, temperature: float = 1.0):
        super().__init__()
        self.d_model = d_model
        self.q_input_dim = q_input_dim
        self.target_sparsity = target_sparsity
        self.temperature = temperature

        # Projection for gate decision
        self.gate_proj = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model // 4),
            nn.GELU(),
            nn.Linear(d_model // 4, 1),
        )

        # Projection to quantum input dimension
        self.q_proj = nn.Linear(d_model, q_input_dim)

        # Projection from quantum output back to the model dimension.
        # Registered in __init__ (rather than lazily in forward) so its
        # parameters are tracked by optimizers and state_dict.
        self.q_out_proj = nn.Linear(q_input_dim, d_model)

        # Learned fallback embedding for tokens that bypass quantum
        self.fallback_emb = nn.Parameter(torch.zeros(d_model))

        # Statistics
        self.register_buffer("total_tokens", torch.tensor(0, dtype=torch.long))
        self.register_buffer("quantum_tokens", torch.tensor(0, dtype=torch.long))
        self.register_buffer("_ema_sparsity", torch.tensor(target_sparsity))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Route tokens selectively through quantum.

        Args:
            x: (*batch, seq_len, d_model)

        Returns:
            output: (*batch, seq_len, d_model) — quantum output for routed
                tokens, fallback embedding for bypassed tokens
            mask: (*batch, seq_len) — which tokens went through quantum (bool)
        """
        *batch_dims, seq_len, d_model = x.shape

        # Gate decision
        gate_logits = self.gate_proj(x).squeeze(-1)  # (*, seq_len)
        soft_mask = torch.sigmoid(gate_logits / self.temperature)

        # Straight-through: hard forward, soft backward
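        # (forward value = the hard 0/1 mask; backward gradient = the soft
        # sigmoid's, so the gate stays trainable despite discrete routing)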
        hard_mask = (soft_mask > 0.5).float()
        mask = hard_mask.detach() + soft_mask - soft_mask.detach()

        # Project all tokens to the quantum input dimension; bypassed
        # tokens are masked out of the output below
        q_input = self.q_proj(x)  # (*, seq_len, q_input_dim)

        # TODO: actual quantum circuit call goes here
        # For now: a classical stand-in (GELU nonlinearity), projected
        # back to d_model with the learned output projection
        quantum_out = self.q_out_proj(F.gelu(q_input))

        # Gate output: quantum result where routed, fallback embedding elsewhere
        mask_expanded = mask.unsqueeze(-1)  # (*, seq_len, 1)
        output = mask_expanded * quantum_out + (1.0 - mask_expanded) * self.fallback_emb

        # Update statistics
        with torch.no_grad():
            n_tokens = seq_len * max(1, math.prod(batch_dims))
            n_quantum = int(mask_expanded.sum().item())
            self.total_tokens += n_tokens
            self.quantum_tokens += n_quantum
            actual_rate = n_quantum / max(n_tokens, 1)
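            # EMA of the skip rate: ema = 0.99 * ema + 0.01 * (1 - actual_rate)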
            self._ema_sparsity.mul_(0.99).add_(
                (1 - actual_rate), alpha=0.01
            )

        return output, mask.detach().bool()

    @property
    def sparsity(self) -> float:
        """Fraction of tokens that SKIP the quantum circuit."""
        return self._ema_sparsity.item()

    @property
    def usage_percent(self) -> float:
        """Fraction of tokens that use the quantum circuit."""
        return 1.0 - self.sparsity
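
    # Illustrative addition, not part of the original API: the module
    # docstring mentions a sparsity target, and one standard way to enforce
    # it is an auxiliary loss added to the training objective by the caller.
    def sparsity_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Squared deviation of expected quantum usage from (1 - target_sparsity).

        Recomputes the gate probabilities from ``x`` so the result is
        differentiable (``forward`` returns a detached boolean mask).
        """
        p_quantum = torch.sigmoid(self.gate_proj(x).squeeze(-1) / self.temperature)
        return (p_quantum.mean() - (1.0 - self.target_sparsity)) ** 2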

    def reset_stats(self):
        self.total_tokens.zero_()
        self.quantum_tokens.zero_()
        self._ema_sparsity.fill_(self.target_sparsity)

    def reset_state(self):
        """Full reset for clean evaluation runs."""
        self.reset_stats()
        for m in self.modules():
            if hasattr(m, "reset_parameters"):
                m.reset_parameters()

    def extra_repr(self) -> str:
        return (f"d_model={self.d_model}, q_dim={self.q_input_dim}, "
                f"target_sparsity={self.target_sparsity:.1%}")


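# Minimal usage sketch (illustrative: the model dims, seed, and the
# auxiliary-loss wiring below are assumptions about the surrounding
# training code, not part of the original module).
if __name__ == "__main__":
    torch.manual_seed(0)

    router = QuantumRouter(d_model=64, q_input_dim=4, target_sparsity=0.7)
    x = torch.randn(2, 16, 64)

    out, mask = router(x)
    print(f"output: {tuple(out.shape)}, quantum tokens: {int(mask.sum())}/{mask.numel()}")
    print(f"EMA sparsity: {router.sparsity:.1%}, usage: {router.usage_percent:.1%}")

    # The straight-through estimator lets gradients reach the gate network
    # even though the forward pass used a hard 0/1 routing mask.
    loss = out.sum() + router.sparsity_loss(x)
    loss.backward()
    print(f"gate grad norm: {router.gate_proj[1].weight.grad.norm().item():.4f}")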