| """ |
| Adaptive TT-Rank Scheduler. |
| |
| Core novelty of Q-TensorFormer: adjusts tensor rank dynamically |
| based on per-input complexity, estimated via attention entropy. |
| |
| r(input) = r_min + α × normalized_entropy × (r_max - r_min) |
| |
| Supports: |
| - EMA smoothing to prevent oscillation |
| - Budget-capped ranks |
| - Deterministic rounding with hysteresis |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import math |
|
|
|
|
| class RankScheduler(nn.Module): |
| """ |
| Attention entropy → TT-rank scheduler. |
| |
| Parameters |
| ---------- |
| r_min : int |
| Minimum tensor rank (maximum compression). |
| r_max : int |
| Maximum tensor rank (minimum compression). |
| alpha : float |
| Sensitivity: how much entropy changes the rank. |
| alpha=0 → fixed rank r_min. |
| alpha=1 → rank fully spans r_min to r_max. |
| alpha=2.0 → aggressive scaling (default). |
| smoothing : float |
| EMA decay factor (0.9 = smooth, 0 = no history). |
| """ |
|
|
| def __init__(self, r_min: int = 2, r_max: int = 8, |
| alpha: float = 2.0, smoothing: float = 0.9): |
| super().__init__() |
| self.r_min = r_min |
| self.r_max = r_max |
| self.alpha = alpha |
| self.smoothing = smoothing |
|
|
| self.register_buffer("_ema_entropy", torch.tensor(0.5)) |
| self.register_buffer("_ema_rank", torch.tensor((r_min + r_max) // 2, dtype=torch.float)) |
| self.register_buffer("_counter", torch.tensor(0, dtype=torch.long)) |
|
|
| |
| self.learned_alpha = nn.Parameter(torch.tensor(float(alpha)), requires_grad=False) |
|
|
| def forward(self, entropy: torch.Tensor, seq_len: int = None) -> int: |
| """ |
| Compute rank from attention entropy. |
| |
| Args: |
| entropy: Scalar or 0-dim tensor (mean attention entropy). |
| seq_len: Sequence length for normalization (optional). |
| |
| Returns: |
| Integer tensor rank. |
| """ |
| if entropy.dim() > 0: |
| entropy = entropy.mean() |
|
|
| |
| if seq_len is not None and seq_len > 1: |
| norm_factor = math.log(seq_len) |
| normalized = torch.clamp(entropy / max(norm_factor, 1e-8), 0.0, 1.0) |
| else: |
| normalized = torch.clamp(torch.tanh(entropy / 2.0), 0.0, 1.0) |
|
|
| |
| self._ema_entropy.mul_(self.smoothing).add_(normalized, alpha=1.0 - self.smoothing) |
| smoothed = self._ema_entropy |
|
|
| |
| alpha_val = self.learned_alpha.item() |
| span = self.r_max - self.r_min |
| raw = self.r_min + alpha_val * smoothed.item() * span |
|
|
| |
| self._ema_rank.mul_(0.7).add_(raw, alpha=0.3) |
| rank = int(torch.round(self._ema_rank).item()) |
|
|
| |
| rank = max(self.r_min, min(self.r_max, rank)) |
| self._counter.add_(1) |
| return rank |
|
|
| def reset(self): |
| """Reset EMA state.""" |
| self._ema_entropy.fill_(0.5) |
| self._ema_rank.fill_((self.r_min + self.r_max) / 2.0) |
| self._counter.fill_(0) |
|
|
| @property |
| def current_rank(self) -> float: |
| return self._ema_rank.item() |
|
|
| @property |
| def current_entropy(self) -> float: |
| return self._ema_entropy.item() |
|
|
|
|
class BudgetAwareScheduler(nn.Module):
    """
    Extends RankScheduler with deployment budget constraints.

    Caps the tensor rank produced by the wrapped scheduler so the
    estimated parameter count fits a deployment budget.

    NOTE(review): only the parameter budget is enforced in forward();
    max_latency_ms and max_energy_uj are stored for callers but not yet
    applied here.
    """

    def __init__(self, scheduler: "RankScheduler",
                 max_params: int = None,
                 max_latency_ms: float = None,
                 max_energy_uj: float = None):
        super().__init__()
        self.scheduler = scheduler            # underlying entropy → rank scheduler
        self.max_params = max_params          # parameter budget (None = unlimited)
        self.max_latency_ms = max_latency_ms  # stored, not yet enforced
        self.max_energy_uj = max_energy_uj    # stored, not yet enforced

    def forward(self, entropy: torch.Tensor, seq_len: int = None,
                param_factors: dict = None) -> int:
        """
        Compute rank with budget constraints.

        Args:
            entropy: Attention entropy.
            seq_len: Sequence length.
            param_factors: Dict mapping rank → estimated total parameters.
                Ranks missing from the dict are treated as over budget.

        Returns:
            Budget-constrained rank. Never drops below scheduler.r_min,
            even when no rank satisfies the budget.
        """
        rank = self.scheduler(entropy, seq_len)

        # FIX: `if ... and self.max_params` silently disabled a budget of 0;
        # test explicitly against None. Empty param_factors still means
        # "no estimates available" and leaves the rank untouched.
        if param_factors and self.max_params is not None:
            # Walk the rank down until the estimate fits the budget (or we
            # hit the floor).
            while rank > self.scheduler.r_min:
                est = param_factors.get(rank, float("inf"))
                if est <= self.max_params:
                    break
                rank -= 1

        return rank

    def reset(self):
        """Reset the wrapped scheduler's EMA state."""
        self.scheduler.reset()
|
|