"""
Adaptive TT-Rank Scheduler.

Core novelty of Q-TensorFormer: adjusts the tensor rank dynamically
based on per-input complexity, estimated via attention entropy:

    r(input) = r_min + α × normalized_entropy × (r_max - r_min)

Supports:
- EMA smoothing to prevent oscillation
- Budget-capped ranks
- Deterministic rounding with hysteresis
"""
import math
from typing import Optional

import torch
import torch.nn as nn

class RankScheduler(nn.Module):
    """
    Attention entropy → TT-rank scheduler.

    Parameters
    ----------
    r_min : int
        Minimum tensor rank (maximum compression).
    r_max : int
        Maximum tensor rank (minimum compression).
    alpha : float
        Sensitivity: how much entropy changes the rank.
        alpha=0 → fixed rank r_min.
        alpha=1 → rank fully spans r_min to r_max.
        alpha=2.0 → aggressive scaling (default).
    smoothing : float
        EMA decay factor (0.9 = smooth, 0 = no history).
    """

    def __init__(self, r_min: int = 2, r_max: int = 8,
                 alpha: float = 2.0, smoothing: float = 0.9):
        super().__init__()
        self.r_min = r_min
        self.r_max = r_max
        self.alpha = alpha
        self.smoothing = smoothing
        # EMA state lives in buffers so it follows the module across
        # devices/dtypes but stays out of gradient updates.
        self.register_buffer("_ema_entropy", torch.tensor(0.5))
        self.register_buffer("_ema_rank", torch.tensor((r_min + r_max) / 2.0))
        self.register_buffer("_counter", torch.tensor(0, dtype=torch.long))
        # Alpha is stored as a frozen Parameter; flip requires_grad=True to learn it.
        self.learned_alpha = nn.Parameter(torch.tensor(float(alpha)), requires_grad=False)
    def forward(self, entropy: torch.Tensor, seq_len: Optional[int] = None) -> int:
        """
        Compute rank from attention entropy.

        Args:
            entropy: Attention entropy; non-scalar tensors are averaged to a scalar.
            seq_len: Sequence length for normalization (optional).

        Returns:
            Integer TT-rank in [r_min, r_max].
        """
        if entropy.dim() > 0:
            entropy = entropy.mean()
        # Normalize entropy to [0, 1]. log(seq_len) is the maximum possible
        # entropy of an attention row over seq_len tokens.
        if seq_len is not None and seq_len > 1:
            norm_factor = math.log(seq_len)
            normalized = torch.clamp(entropy / max(norm_factor, 1e-8), 0.0, 1.0)
        else:
            normalized = torch.clamp(torch.tanh(entropy / 2.0), 0.0, 1.0)
        # EMA smoothing; detach so the in-place buffer update never drags
        # a grad-tracking tensor into the autograd graph.
        self._ema_entropy.mul_(self.smoothing).add_(normalized.detach(), alpha=1.0 - self.smoothing)
        smoothed = self._ema_entropy
        # Map to rank: r = r_min + alpha * norm * (r_max - r_min)
        alpha_val = self.learned_alpha.item()
        span = self.r_max - self.r_min
        raw = self.r_min + alpha_val * smoothed.item() * span
        # Round with hysteresis: a second EMA on the raw rank damps
        # step-to-step oscillation between adjacent integer ranks.
        self._ema_rank.mul_(0.7).add_(raw, alpha=0.3)
        rank = int(torch.round(self._ema_rank).item())
        # Clamp to [r_min, r_max] and count calls
        rank = max(self.r_min, min(self.r_max, rank))
        self._counter.add_(1)
        return rank
    def reset(self):
        """Reset EMA state."""
        self._ema_entropy.fill_(0.5)
        self._ema_rank.fill_((self.r_min + self.r_max) / 2.0)
        self._counter.fill_(0)

    @property
    def current_rank(self) -> float:
        return self._ema_rank.item()

    @property
    def current_entropy(self) -> float:
        return self._ema_entropy.item()
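

# One possible way to obtain the scalar entropy consumed by RankScheduler:
# mean Shannon entropy of softmaxed attention rows. A sketch, not part of
# the scheduler API; any entropy estimate on a matching scale works.
def attention_entropy(attn_weights: torch.Tensor) -> torch.Tensor:
    """Mean Shannon entropy (in nats) over rows of shape [..., L, L] weights."""
    p = attn_weights.clamp_min(1e-12)
    return -(p * p.log()).sum(dim=-1).mean()
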
class BudgetAwareScheduler(nn.Module):
    """
    Extends RankScheduler with deployment budget constraints.

    Automatically caps tensor rank to meet:
    - Max parameter budget
    - Max latency target
    - Max energy per query

    Note: only the parameter budget is enforced in ``forward``; the latency
    and energy targets are stored for use by downstream tooling.
    """

    def __init__(self, scheduler: RankScheduler,
                 max_params: Optional[int] = None,
                 max_latency_ms: Optional[float] = None,
                 max_energy_uj: Optional[float] = None):
        super().__init__()
        self.scheduler = scheduler
        self.max_params = max_params
        self.max_latency_ms = max_latency_ms
        self.max_energy_uj = max_energy_uj
    def forward(self, entropy: torch.Tensor, seq_len: Optional[int] = None,
                param_factors: Optional[dict] = None) -> int:
        """
        Compute rank with budget constraints.

        Args:
            entropy: Attention entropy.
            seq_len: Sequence length.
            param_factors: Dict mapping rank → estimated total parameters.

        Returns:
            Budget-constrained rank.
        """
        rank = self.scheduler(entropy, seq_len)
        if param_factors and self.max_params:
            # Step down until the estimated parameter count fits the
            # budget, or we bottom out at r_min.
            while rank > self.scheduler.r_min:
                est = param_factors.get(rank, float("inf"))
                if est <= self.max_params:
                    break
                rank -= 1
        return rank
    def reset(self):
        self.scheduler.reset()
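

if __name__ == "__main__":
    # Minimal usage sketch; all numbers here are illustrative, made-up values.
    torch.manual_seed(0)
    sched = RankScheduler(r_min=2, r_max=8, alpha=1.0)
    attn = torch.softmax(torch.randn(2, 4, 64, 64), dim=-1)
    ent = attention_entropy(attn)
    print("unconstrained rank:", sched(ent, seq_len=64))

    # Hypothetical rank -> parameter-count table to exercise the budget cap:
    # ranks 7-8 exceed the 1M-parameter budget, so the cap steps down to 6.
    budgeted = BudgetAwareScheduler(sched, max_params=1_000_000)
    param_factors = {r: 150_000 * r for r in range(2, 9)}
    print("budgeted rank:", budgeted(ent, seq_len=64, param_factors=param_factors))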