File size: 4,754 Bytes
b9c4adf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
Adaptive TT-Rank Scheduler.

Core novelty of Q-TensorFormer: adjusts tensor rank dynamically
based on per-input complexity, estimated via attention entropy.

r(input) = r_min + α × normalized_entropy × (r_max - r_min)

Supports:
  - EMA smoothing to prevent oscillation
  - Budget-capped ranks
  - Deterministic rounding with hysteresis
"""

import torch
import torch.nn as nn
import math


class RankScheduler(nn.Module):
    """
    Attention entropy → TT-rank scheduler.

    Maps the mean attention entropy of an input to a tensor-train rank:

        r(input) = r_min + alpha * normalized_entropy * (r_max - r_min)

    Two EMA buffers (one on the normalized entropy, one on the raw rank)
    smooth the schedule so the integer rank does not oscillate between
    consecutive steps.

    Parameters
    ----------
    r_min : int
        Minimum tensor rank (maximum compression). Must be >= 1.
    r_max : int
        Maximum tensor rank (minimum compression). Must be >= r_min.
    alpha : float
        Sensitivity: how much entropy changes the rank.
        alpha=0 → fixed rank r_min.
        alpha=1 → rank fully spans r_min to r_max.
        alpha=2.0 → aggressive scaling (default).
    smoothing : float
        EMA decay factor in [0, 1) (0.9 = smooth, 0 = no history).

    Raises
    ------
    ValueError
        If r_min < 1, r_max < r_min, or smoothing is outside [0, 1).
    """

    def __init__(self, r_min: int = 2, r_max: int = 8,
                 alpha: float = 2.0, smoothing: float = 0.9):
        super().__init__()
        # Fail fast on inconsistent configuration: an inverted rank range or
        # smoothing >= 1 would silently produce nonsense / a frozen EMA.
        if r_min < 1:
            raise ValueError(f"r_min must be >= 1, got {r_min}")
        if r_max < r_min:
            raise ValueError(f"r_max ({r_max}) must be >= r_min ({r_min})")
        if not 0.0 <= smoothing < 1.0:
            raise ValueError(f"smoothing must be in [0, 1), got {smoothing}")

        self.r_min = r_min
        self.r_max = r_max
        self.alpha = alpha
        self.smoothing = smoothing

        # Running state lives in buffers so it follows .to()/.cuda() and is
        # persisted through state_dict.
        self.register_buffer("_ema_entropy", torch.tensor(0.5))
        self.register_buffer("_ema_rank", torch.tensor((r_min + r_max) // 2, dtype=torch.float))
        self.register_buffer("_counter", torch.tensor(0, dtype=torch.long))

        # Kept as a (frozen) Parameter so alpha can later be unfrozen and
        # learned. NOTE: forward() reads this value, not self.alpha.
        self.learned_alpha = nn.Parameter(torch.tensor(float(alpha)), requires_grad=False)

    def forward(self, entropy: torch.Tensor, seq_len: int = None) -> int:
        """
        Compute rank from attention entropy.

        Args:
            entropy: Mean attention entropy. A 0-dim tensor is used as-is,
                higher-dim tensors are reduced with ``.mean()``, and plain
                Python numbers are accepted and converted.
            seq_len: Sequence length for normalization (optional). When
                given (> 1), entropy is divided by its theoretical maximum
                log(seq_len); otherwise a tanh squashing is used.

        Returns:
            Integer tensor rank, clamped to [r_min, r_max].
        """
        # Accept plain floats/ints as well as tensors.
        entropy = torch.as_tensor(entropy, dtype=torch.float)
        if entropy.dim() > 0:
            entropy = entropy.mean()

        # Normalize entropy to [0, 1].
        if seq_len is not None and seq_len > 1:
            # Uniform attention over seq_len positions has entropy
            # log(seq_len), so maximally diffuse attention maps to 1.0.
            norm_factor = math.log(seq_len)
            normalized = torch.clamp(entropy / max(norm_factor, 1e-8), 0.0, 1.0)
        else:
            # No length available: soft-normalize with tanh instead.
            normalized = torch.clamp(torch.tanh(entropy / 2.0), 0.0, 1.0)

        # EMA smoothing of the normalized entropy.
        self._ema_entropy.mul_(self.smoothing).add_(normalized, alpha=1.0 - self.smoothing)
        smoothed = self._ema_entropy

        # Map to rank: r = r_min + alpha * norm * (r_max - r_min)
        alpha_val = self.learned_alpha.item()
        span = self.r_max - self.r_min
        raw = self.r_min + alpha_val * smoothed.item() * span

        # Second EMA on the raw rank acts as hysteresis before rounding,
        # so the integer rank doesn't flip on small entropy changes.
        self._ema_rank.mul_(0.7).add_(raw, alpha=0.3)
        rank = int(torch.round(self._ema_rank).item())

        # alpha > 1 can overshoot the span; clamp back into bounds.
        rank = max(self.r_min, min(self.r_max, rank))
        self._counter.add_(1)
        return rank

    def reset(self):
        """Reset EMA state (entropy, rank, and call counter) to initial values."""
        self._ema_entropy.fill_(0.5)
        self._ema_rank.fill_((self.r_min + self.r_max) / 2.0)
        self._counter.fill_(0)

    @property
    def current_rank(self) -> float:
        """Current (un-rounded) EMA-smoothed rank."""
        return self._ema_rank.item()

    @property
    def current_entropy(self) -> float:
        """Current EMA-smoothed normalized entropy in [0, 1]."""
        return self._ema_entropy.item()


class BudgetAwareScheduler(nn.Module):
    """
    Extends RankScheduler with deployment budget constraints.

    Caps the entropy-driven rank so the estimated parameter count stays
    within ``max_params``.

    NOTE(review): ``max_latency_ms`` and ``max_energy_uj`` are stored but
    not yet enforced anywhere in this class — confirm whether enforcement
    happens elsewhere or is still TODO.

    Parameters
    ----------
    scheduler : RankScheduler
        The underlying entropy-driven scheduler being wrapped.
    max_params : int, optional
        Maximum allowed parameter budget. ``None`` disables the constraint.
    max_latency_ms : float, optional
        Maximum latency target (currently unenforced).
    max_energy_uj : float, optional
        Maximum energy per query (currently unenforced).
    """

    def __init__(self, scheduler: "RankScheduler",
                 max_params: int = None,
                 max_latency_ms: float = None,
                 max_energy_uj: float = None):
        super().__init__()
        self.scheduler = scheduler
        self.max_params = max_params
        self.max_latency_ms = max_latency_ms
        self.max_energy_uj = max_energy_uj

    def forward(self, entropy: torch.Tensor, seq_len: int = None,
                param_factors: dict = None) -> int:
        """
        Compute rank with budget constraints.

        Args:
            entropy: Attention entropy (forwarded to the wrapped scheduler).
            seq_len: Sequence length (forwarded to the wrapped scheduler).
            param_factors: Dict mapping rank → estimated total parameters.

        Returns:
            Budget-constrained rank. If even ``r_min`` exceeds the budget,
            ``r_min`` is returned as a best effort.
        """
        rank = self.scheduler(entropy, seq_len)

        # `is not None` (not truthiness): a budget of 0 is a real constraint,
        # not "unconstrained".
        if param_factors and self.max_params is not None:
            # Walk down from the entropy-chosen rank to the highest rank
            # whose estimate fits the budget; ranks missing from the dict
            # are treated as over budget.
            while rank > self.scheduler.r_min:
                est = param_factors.get(rank, float("inf"))
                if est <= self.max_params:
                    break
                rank -= 1

        return rank

    def reset(self):
        """Reset the wrapped scheduler's EMA state."""
        self.scheduler.reset()