File size: 5,568 Bytes
b9c4adf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""
Budget-constrained optimization.

Enforces deployment constraints during training and inference:
  - Maximum parameter count
  - Maximum inference latency
  - Maximum energy per query

The model auto-adjusts tensor ranks to meet these constraints.
"""

import torch
import time
import math
from typing import Optional, Dict
from .config import BudgetConfig


class BudgetTracker:
    """
    Tracks whether a model meets deployment budget constraints.

    Checks at each validation step:
      - Parameter count ≤ max_params
      - Estimated latency ≤ max_latency_ms
      - Estimated energy ≤ max_energy_per_query
    """

    def __init__(self, budget: "BudgetConfig"):
        # Budget thresholds; any threshold left as None is unconstrained.
        self.budget = budget

    def exceeds_budget(self, metrics: Dict, model_config) -> bool:
        """
        Check if current metrics exceed any budget constraint.

        Args:
            metrics: Measured values; keys "total_params", "latency_ms",
                and "energy_uj" are consulted. A missing key defaults to 0
                and therefore never trips its constraint.
            model_config: Unused; kept for interface compatibility.

        Returns:
            True if any constraint is violated.
        """
        if self.budget.max_params is not None:
            if metrics.get("total_params", 0) > self.budget.max_params:
                print(f"[BUDGET] Params exceeded: {metrics['total_params']} > {self.budget.max_params}")
                return True

        if self.budget.max_latency_ms is not None:
            if metrics.get("latency_ms", 0) > self.budget.max_latency_ms:
                print(f"[BUDGET] Latency exceeded: {metrics['latency_ms']:.2f} > {self.budget.max_latency_ms}")
                return True

        if self.budget.max_energy_per_query is not None:
            if metrics.get("energy_uj", 0) > self.budget.max_energy_per_query:
                print(f"[BUDGET] Energy exceeded: {metrics['energy_uj']:.2f} > {self.budget.max_energy_per_query}")
                return True

        return False

    def estimate_latency(self, model, seq_len: int = 128,
                         n_warmup: int = 3, n_measure: int = 10) -> float:
        """
        Estimate inference latency for a sequence of length seq_len.

        Runs n_warmup untimed forward passes, then averages the wall-clock
        time of n_measure timed passes. The model is switched to eval mode
        as a side effect.

        Returns:
            Mean latency in milliseconds.
        """
        device = next(model.parameters()).device
        model.eval()

        # Dummy token ids; assumes a vocab of at least 1000 — TODO confirm.
        dummy = torch.randint(0, 1000, (1, seq_len)).to(device)

        # Warmup (untimed)
        with torch.no_grad():
            for _ in range(n_warmup):
                _ = model(dummy)

        # BUGFIX: flush queued warmup kernels before starting the clock,
        # otherwise the first timed iteration absorbs warmup latency.
        if device.type == "cuda":
            torch.cuda.synchronize()

        latencies = []
        with torch.no_grad():
            for _ in range(n_measure):
                # perf_counter is monotonic and higher-resolution than time.time
                t0 = time.perf_counter()
                _ = model(dummy)
                if device.type == "cuda":
                    # Wait for the async kernel to finish before stopping the clock.
                    torch.cuda.synchronize()
                latencies.append((time.perf_counter() - t0) * 1000)

        return sum(latencies) / len(latencies)

    def estimate_parameter_budget(self, model, tt_rank: int) -> int:
        """
        Estimate the model's total parameter count at a given TT rank.

        Uses the rough scaling law TT params ~ O(rank^2): non-TT parameters
        stay fixed while the TT share is rescaled quadratically. Falls back
        to the current total when the model exposes no ``tt_params``
        attribute or no ``config`` with the current rank.
        """
        current = sum(p.numel() for p in model.parameters())
        if hasattr(model, "tt_params"):
            current_rank = getattr(model, "config", None)
            if current_rank:
                current_rank = current_rank.tt_rank
            else:
                # No config to read the current rank from; cannot rescale.
                return current
            # Rough quadratic rescaling of the TT share only.
            tt_now = model.tt_params
            tt_new = tt_now * (tt_rank / max(current_rank, 1)) ** 2
            return int(current - tt_now + tt_new)
        return current


class EnergyEstimator:
    """
    Energy consumption estimator using FLOPs as proxy.

    Approximate conversions (hardware-dependent):
      - CPU inference: ~5 pJ/FLOP
      - GPU inference (A100): ~0.5 pJ/FLOP
      - Edge inference: ~10 pJ/FLOP
    """

    # Energy per FLOP in microjoules (μJ)
    ENERGY_PER_FLOP = {
        "cpu": 5e-6,      # 5 pJ → 5e-6 μJ
        "gpu_a100": 0.5e-6,  # 0.5 pJ → 0.5e-6 μJ
        "edge": 10e-6,    # 10 pJ → 10e-6 μJ
    }

    def __init__(self, hardware: str = "cpu"):
        self.hardware = hardware
        # Unknown hardware targets fall back to the CPU figure (5e-6 μJ/FLOP).
        self.energy_per_flop = self.ENERGY_PER_FLOP.get(hardware, 5e-6)

    def estimate(self, model, batch_size: int = 1,
                 seq_len: int = 128) -> float:
        """
        Estimate energy consumption in μJ for one forward pass.

        Returns:
            Energy in microjoules.
        """
        # FLOP proxy × per-FLOP cost for the configured hardware.
        return self._estimate_flops(model, batch_size, seq_len) * self.energy_per_flop

    @staticmethod
    def _estimate_flops(model, batch_size: int, seq_len: int) -> int:
        """Estimate FLOPs for one forward pass (2 × params × batch × seq_len)."""
        n_params = 0
        for tensor in model.parameters():
            n_params += tensor.numel()
        # One multiply-add per parameter per token.
        return int(2 * n_params * batch_size * seq_len)

    def set_hardware(self, hardware: str):
        """Change hardware target and refresh the per-FLOP cost."""
        self.hardware = hardware
        self.energy_per_flop = self.ENERGY_PER_FLOP.get(hardware, 5e-6)


def find_feasible_rank(model, budget: "BudgetConfig",
                       param_factors: Optional[Dict[int, int]] = None) -> int:
    """
    Find the maximum TT rank that meets budget constraints.

    Searches downward from the model's current rank and returns the first
    (i.e. largest) rank whose estimated parameter count fits the budget.

    Args:
        model: Model to analyze; ``model.config.tt_rank`` is used as the
            starting rank when present, otherwise 8.
        budget: Budget constraints; only ``max_params`` is consulted.
        param_factors: Optional mapping rank → estimated parameter count.
            Ranks missing from the mapping are treated as infeasible.
            When None (or empty), no estimates are available and the
            current rank is returned unchanged.

    Returns:
        Maximum feasible rank (always ≥ 1).
    """
    current_rank = 8  # default when the model exposes no config
    if hasattr(model, "config"):
        current_rank = model.config.tt_rank

    # No estimates to check against — keep the current rank (floor of 1,
    # matching the exhausted-search fallback below).
    if not param_factors:
        return max(current_rank, 1)

    for rank in range(current_rank, 0, -1):
        # Missing ranks estimate to infinity and are skipped as infeasible.
        est_params = param_factors.get(rank, float("inf"))
        # BUGFIX: explicit None check so max_params == 0 still constrains
        # (the old truthiness test treated 0 as "no constraint").
        if budget.max_params is not None and est_params > budget.max_params:
            continue
        return rank
    return 1