"""Adaptive TT-Rank Scheduler.

Core novelty of Q-TensorFormer: adjusts tensor rank dynamically based on
per-input complexity, estimated via attention entropy.

    r(input) = r_min + alpha * normalized_entropy * (r_max - r_min)

Supports:
- EMA smoothing to prevent oscillation
- Budget-capped ranks
- Deterministic rounding with hysteresis
"""

import math
from typing import Optional

import torch
import torch.nn as nn


class RankScheduler(nn.Module):
    """Attention entropy -> TT-rank scheduler.

    Parameters
    ----------
    r_min : int
        Minimum tensor rank (maximum compression).
    r_max : int
        Maximum tensor rank (minimum compression).
    alpha : float
        Sensitivity: how much entropy changes the rank.
        alpha=0 -> fixed rank r_min.
        alpha=1 -> rank fully spans r_min to r_max.
        alpha=2.0 -> aggressive scaling (default).
    smoothing : float
        EMA decay factor (0.9 = smooth, 0 = no history).

    Raises
    ------
    ValueError
        If r_min > r_max or smoothing is outside [0, 1).
    """

    def __init__(self,
                 r_min: int = 2,
                 r_max: int = 8,
                 alpha: float = 2.0,
                 smoothing: float = 0.9):
        super().__init__()
        if r_min > r_max:
            raise ValueError(f"r_min ({r_min}) must be <= r_max ({r_max})")
        if not 0.0 <= smoothing < 1.0:
            raise ValueError(f"smoothing must be in [0, 1), got {smoothing}")

        self.r_min = r_min
        self.r_max = r_max
        self.alpha = alpha
        self.smoothing = smoothing

        # Buffers (not parameters): state follows .to(device) and lands in
        # state_dict, but is never seen by the optimizer.
        self.register_buffer("_ema_entropy", torch.tensor(0.5))
        self.register_buffer(
            "_ema_rank",
            torch.tensor((r_min + r_max) // 2, dtype=torch.float))
        self.register_buffer("_counter", torch.tensor(0, dtype=torch.long))

        # Optionally learn alpha. Frozen by default; enable learning by
        # flipping requires_grad externally.
        self.learned_alpha = nn.Parameter(torch.tensor(float(alpha)),
                                          requires_grad=False)

    def forward(self,
                entropy: torch.Tensor,
                seq_len: Optional[int] = None) -> int:
        """Compute rank from attention entropy.

        Args:
            entropy: Scalar or 0-dim tensor (mean attention entropy).
                Non-scalar inputs are reduced by mean.
            seq_len: Sequence length for normalization (optional). When
                given, entropy is normalized by log(seq_len) — the entropy
                of a uniform distribution over seq_len positions; otherwise
                a tanh squashing is used.

        Returns:
            Integer tensor rank, clamped to [r_min, r_max].
        """
        if entropy.dim() > 0:
            entropy = entropy.mean()
        # Detach: the EMA buffers are bookkeeping state, not part of the
        # autograd graph. In-place mutation of a buffer with a
        # grad-tracking tensor would corrupt the backward pass.
        entropy = entropy.detach()

        # Normalize entropy to [0, 1].
        if seq_len is not None and seq_len > 1:
            norm_factor = math.log(seq_len)
            normalized = torch.clamp(entropy / max(norm_factor, 1e-8),
                                     0.0, 1.0)
        else:
            normalized = torch.clamp(torch.tanh(entropy / 2.0), 0.0, 1.0)

        # EMA smoothing to prevent rank oscillation between steps.
        self._ema_entropy.mul_(self.smoothing).add_(
            normalized, alpha=1.0 - self.smoothing)
        smoothed = self._ema_entropy

        # Map to rank: r = r_min + alpha * norm * (r_max - r_min).
        # With alpha > 1 the raw value can overshoot r_max; the final
        # clamp below brings it back into range.
        alpha_val = self.learned_alpha.item()
        span = self.r_max - self.r_min
        raw = self.r_min + alpha_val * smoothed.item() * span

        # A second, faster EMA on the rank itself acts as hysteresis so
        # the integer rank does not flap between adjacent values.
        self._ema_rank.mul_(0.7).add_(raw, alpha=0.3)
        rank = int(torch.round(self._ema_rank).item())

        # Clamp to the valid range and count the call.
        rank = max(self.r_min, min(self.r_max, rank))
        self._counter.add_(1)
        return rank

    def reset(self):
        """Reset EMA state to the same values used at construction."""
        self._ema_entropy.fill_(0.5)
        self._ema_rank.fill_((self.r_min + self.r_max) / 2.0)
        self._counter.fill_(0)

    @property
    def current_rank(self) -> float:
        """Current (pre-rounding) EMA-smoothed rank."""
        return self._ema_rank.item()

    @property
    def current_entropy(self) -> float:
        """Current EMA-smoothed normalized entropy."""
        return self._ema_entropy.item()


class BudgetAwareScheduler(nn.Module):
    """Extends RankScheduler with deployment budget constraints.

    Automatically caps tensor rank to meet:
    - Max parameter budget
    - Max latency target
    - Max energy per query

    Note: only the parameter budget is enforced in forward();
    max_latency_ms and max_energy_uj are stored for external use.
    """

    def __init__(self,
                 scheduler: RankScheduler,
                 max_params: Optional[int] = None,
                 max_latency_ms: Optional[float] = None,
                 max_energy_uj: Optional[float] = None):
        super().__init__()
        self.scheduler = scheduler
        self.max_params = max_params
        self.max_latency_ms = max_latency_ms
        self.max_energy_uj = max_energy_uj

    def forward(self,
                entropy: torch.Tensor,
                seq_len: Optional[int] = None,
                param_factors: Optional[dict] = None) -> int:
        """Compute rank with budget constraints.

        Args:
            entropy: Attention entropy.
            seq_len: Sequence length.
            param_factors: Dict mapping rank -> estimated total parameters.

        Returns:
            Budget-constrained rank. Ranks missing from param_factors are
            treated as infinitely expensive; if no rank fits the budget,
            the scheduler's r_min is returned as a floor.
        """
        rank = self.scheduler(entropy, seq_len)

        if param_factors and self.max_params:
            # Walk down from the entropy-chosen rank to the highest rank
            # that still fits within the parameter budget.
            while rank > self.scheduler.r_min:
                est = param_factors.get(rank, float("inf"))
                if est <= self.max_params:
                    break
                rank -= 1

        return rank

    def reset(self):
        """Reset the wrapped scheduler's EMA state."""
        self.scheduler.reset()