# Q-TensorFormer/src/scheduler.py
"""
Adaptive TT-Rank Scheduler.
Core novelty of Q-TensorFormer: adjusts tensor rank dynamically
based on per-input complexity, estimated via attention entropy.
r(input) = r_min + α × normalized_entropy × (r_max - r_min)
Supports:
- EMA smoothing to prevent oscillation
- Budget-capped ranks
- Deterministic rounding with hysteresis
"""
import math
from typing import Optional

import torch
import torch.nn as nn
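
# Worked example of the mapping above (illustrative numbers): with r_min=2,
# r_max=8, alpha=1.0 and a normalized entropy of 0.5,
#     r = 2 + 1.0 * 0.5 * (8 - 2) = 5
# alpha > 1 can push the raw value past r_max on high-entropy inputs; the
# final clamp in forward() then takes over.
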
class RankScheduler(nn.Module):
"""
Attention entropy → TT-rank scheduler.
Parameters
----------
r_min : int
Minimum tensor rank (maximum compression).
r_max : int
Maximum tensor rank (minimum compression).
alpha : float
Sensitivity: how much entropy changes the rank.
alpha=0 → fixed rank r_min.
alpha=1 → rank fully spans r_min to r_max.
        alpha=2.0 → aggressive scaling that saturates at r_max (default).
smoothing : float
EMA decay factor (0.9 = smooth, 0 = no history).
"""
def __init__(self, r_min: int = 2, r_max: int = 8,
alpha: float = 2.0, smoothing: float = 0.9):
super().__init__()
self.r_min = r_min
self.r_max = r_max
self.alpha = alpha
self.smoothing = smoothing
self.register_buffer("_ema_entropy", torch.tensor(0.5))
self.register_buffer("_ema_rank", torch.tensor((r_min + r_max) // 2, dtype=torch.float))
self.register_buffer("_counter", torch.tensor(0, dtype=torch.long))
        # Mirror alpha as an nn.Parameter so it can be made learnable later;
        # it is frozen (requires_grad=False) by default.
        self.learned_alpha = nn.Parameter(torch.tensor(float(alpha)), requires_grad=False)
    def forward(self, entropy: torch.Tensor, seq_len: Optional[int] = None) -> int:
"""
Compute rank from attention entropy.
Args:
entropy: Scalar or 0-dim tensor (mean attention entropy).
seq_len: Sequence length for normalization (optional).
Returns:
Integer tensor rank.
"""
if entropy.dim() > 0:
entropy = entropy.mean()
        # Normalize entropy to [0, 1]. log(seq_len) is the maximum possible
        # entropy (uniform attention over seq_len keys), so dividing by it
        # maps the value into [0, 1]; without seq_len, fall back to a tanh squash.
        if seq_len is not None and seq_len > 1:
            norm_factor = math.log(seq_len)
            normalized = torch.clamp(entropy / max(norm_factor, 1e-8), 0.0, 1.0)
        else:
            normalized = torch.clamp(torch.tanh(entropy / 2.0), 0.0, 1.0)
        # EMA smoothing; detach the new value so the in-place buffer update
        # stays out of the autograd graph when entropy carries gradients.
        self._ema_entropy.mul_(self.smoothing).add_(normalized.detach(), alpha=1.0 - self.smoothing)
        smoothed = self._ema_entropy
# Map to rank: r = r_min + alpha * norm * (r_max - r_min)
alpha_val = self.learned_alpha.item()
span = self.r_max - self.r_min
raw = self.r_min + alpha_val * smoothed.item() * span
        # Second EMA over the raw rank acts as hysteresis: the rounded rank
        # only changes once the smoothed value crosses a .5 boundary.
        self._ema_rank.mul_(0.7).add_(raw, alpha=0.3)
        rank = int(torch.round(self._ema_rank).item())
# Clamp + counter
rank = max(self.r_min, min(self.r_max, rank))
self._counter.add_(1)
return rank
def reset(self):
"""Reset EMA state."""
self._ema_entropy.fill_(0.5)
self._ema_rank.fill_((self.r_min + self.r_max) / 2.0)
self._counter.fill_(0)
@property
def current_rank(self) -> float:
return self._ema_rank.item()
@property
def current_entropy(self) -> float:
return self._ema_entropy.item()
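
# Minimal sketch of how the entropy fed to RankScheduler might be computed.
# This helper is illustrative and not part of the original interface: it
# assumes `attn` holds row-normalized attention weights with keys on the
# last axis, shape (..., queries, keys).
def attention_entropy(attn: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """Mean Shannon entropy (nats) of attention distributions."""
    # -(p * log p) summed over keys, then averaged over batch/heads/queries.
    ent = -(attn * (attn + eps).log()).sum(dim=-1)
    return ent.mean()
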
class BudgetAwareScheduler(nn.Module):
    """
    Extends RankScheduler with deployment budget constraints.

    Caps the scheduled tensor rank so deployment stays within:
    - Max parameter budget (enforced in forward via param_factors)
    - Max latency target (stored for future use; not enforced here)
    - Max energy per query (stored for future use; not enforced here)
    """
    def __init__(self, scheduler: RankScheduler,
                 max_params: Optional[int] = None,
                 max_latency_ms: Optional[float] = None,
                 max_energy_uj: Optional[float] = None):
super().__init__()
self.scheduler = scheduler
self.max_params = max_params
self.max_latency_ms = max_latency_ms
self.max_energy_uj = max_energy_uj
    def forward(self, entropy: torch.Tensor, seq_len: Optional[int] = None,
                param_factors: Optional[dict] = None) -> int:
"""
Compute rank with budget constraints.
Args:
entropy: Attention entropy.
seq_len: Sequence length.
param_factors: Dict mapping rank → estimated total parameters.
Returns:
Budget-constrained rank.
"""
rank = self.scheduler(entropy, seq_len)
        if param_factors and self.max_params is not None:
            # Walk down from the scheduled rank to the highest rank whose
            # estimated parameter count fits the budget; r_min is the floor
            # even if no rank satisfies it.
            while rank > self.scheduler.r_min:
                est = param_factors.get(rank, float("inf"))
                if est <= self.max_params:
                    break
                rank -= 1
return rank
def reset(self):
self.scheduler.reset()
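
# Usage sketch with hypothetical values; real entropies would come from the
# model's attention maps (e.g. via attention_entropy above), and the
# rank -> parameter table would be measured from the TT-decomposed weights.
if __name__ == "__main__":
    torch.manual_seed(0)
    sched = RankScheduler(r_min=2, r_max=8, alpha=2.0, smoothing=0.9)
    budgeted = BudgetAwareScheduler(sched, max_params=1_000_000)
    factors = {r: 150_000 * r for r in range(2, 9)}  # illustrative estimates
    for step in range(5):
        attn = torch.softmax(torch.randn(4, 16, 16), dim=-1)  # fake attention
        ent = attention_entropy(attn)
        r = budgeted(ent, seq_len=16, param_factors=factors)
        print(f"step {step}: entropy={ent.item():.3f} -> rank {r}")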