"""
Budget-constrained optimization.
Enforces deployment constraints during training and inference:
- Maximum parameter count
- Maximum inference latency
- Maximum energy per query
The model auto-adjusts tensor ranks to meet these constraints.
"""
import torch
import time
import math
from typing import Optional, Dict
from .config import BudgetConfig
class BudgetTracker:
    """
    Tracks whether a model meets deployment budget constraints.

    Checks at each validation step:
    - Parameter count ≤ max_params
    - Estimated latency ≤ max_latency_ms
    - Estimated energy ≤ max_energy_per_query
    """

    def __init__(self, budget: "BudgetConfig"):
        self.budget = budget

    def exceeds_budget(self, metrics: Dict, model_config) -> bool:
        """
        Check if current metrics exceed any budget constraint.

        Args:
            metrics: Measured values; reads keys "total_params",
                "latency_ms", "energy_uj". A missing key is treated as 0
                (i.e. the constraint passes).
            model_config: Unused; kept for interface compatibility.

        Returns:
            True if any configured (non-None) constraint is violated.
        """
        if self.budget.max_params is not None:
            if metrics.get("total_params", 0) > self.budget.max_params:
                print(f"[BUDGET] Params exceeded: {metrics['total_params']} > {self.budget.max_params}")
                return True
        if self.budget.max_latency_ms is not None:
            if metrics.get("latency_ms", 0) > self.budget.max_latency_ms:
                print(f"[BUDGET] Latency exceeded: {metrics['latency_ms']:.2f} > {self.budget.max_latency_ms}")
                return True
        if self.budget.max_energy_per_query is not None:
            if metrics.get("energy_uj", 0) > self.budget.max_energy_per_query:
                print(f"[BUDGET] Energy exceeded: {metrics['energy_uj']:.2f} > {self.budget.max_energy_per_query}")
                return True
        return False

    def estimate_latency(self, model, seq_len: int = 128,
                         n_warmup: int = 3, n_measure: int = 10) -> float:
        """
        Estimate inference latency for a sequence of length seq_len.

        Note: puts the model in eval mode as a side effect.

        Returns:
            Mean latency in milliseconds over n_measure runs.
        """
        device = next(model.parameters()).device
        model.eval()
        # Token ids in [0, 1000); assumes the model's vocab is at least
        # that large — TODO confirm against the actual model config.
        dummy = torch.randint(0, 1000, (1, seq_len)).to(device)
        # Warmup
        with torch.no_grad():
            for _ in range(n_warmup):
                _ = model(dummy)
        # Drain queued warmup kernels so the first timed run starts clean;
        # otherwise the first measurement absorbs pending async GPU work.
        if device.type == "cuda":
            torch.cuda.synchronize()
        latencies = []
        with torch.no_grad():
            for _ in range(n_measure):
                # perf_counter is monotonic and high-resolution, unlike
                # time.time() which can jump with wall-clock adjustments.
                t0 = time.perf_counter()
                _ = model(dummy)
                if device.type == "cuda":
                    torch.cuda.synchronize()
                latencies.append((time.perf_counter() - t0) * 1000)
        return sum(latencies) / len(latencies)

    def estimate_parameter_budget(self, model, tt_rank: int) -> int:
        """
        Estimate total parameters at a given TT rank.

        TT-layer parameters scale roughly with rank^2, so the TT portion
        (model.tt_params) is rescaled by (tt_rank / current_rank)^2 while
        the non-TT parameters are left unchanged. Falls back to the
        current count when the model exposes no TT metadata.
        """
        current = sum(p.numel() for p in model.parameters())
        if not hasattr(model, "tt_params"):
            return current
        config = getattr(model, "config", None)
        # Guard both a missing config and a config without tt_rank
        # (the original crashed with AttributeError on the latter).
        current_rank = getattr(config, "tt_rank", None) if config else None
        if not current_rank:
            return current
        # Rough scaling: TT params ~ O(rank^2)
        tt_now = model.tt_params
        tt_new = tt_now * (tt_rank / max(current_rank, 1)) ** 2
        return int(current - tt_now + tt_new)
class EnergyEstimator:
    """
    Energy consumption estimator using FLOPs as proxy.

    Approximate conversions (hardware-dependent):
    - CPU inference: ~5 pJ/FLOP
    - GPU inference (A100): ~0.5 pJ/FLOP
    - Edge inference: ~10 pJ/FLOP
    """

    # Energy per FLOP in microjoules (μJ); 1 pJ == 1e-6 μJ.
    ENERGY_PER_FLOP = {
        "cpu": 5e-6,
        "gpu_a100": 0.5e-6,
        "edge": 10e-6,
    }

    def __init__(self, hardware: str = "cpu"):
        # Delegate to set_hardware so the lookup logic lives in one place.
        self.set_hardware(hardware)

    def set_hardware(self, hardware: str):
        """Change hardware target (unknown names fall back to the CPU rate)."""
        self.hardware = hardware
        self.energy_per_flop = self.ENERGY_PER_FLOP.get(hardware, 5e-6)

    def estimate(self, model, batch_size: int = 1,
                 seq_len: int = 128) -> float:
        """
        Estimate energy consumption in μJ for one forward pass.

        Returns:
            Energy in microjoules.
        """
        return self._estimate_flops(model, batch_size, seq_len) * self.energy_per_flop

    @staticmethod
    def _estimate_flops(model, batch_size: int, seq_len: int) -> int:
        """Rough FLOP count: 2 × params × batch × seq_len (one multiply-add per weight per token)."""
        n_params = sum(w.numel() for w in model.parameters())
        return int(2 * n_params * batch_size * seq_len)
def find_feasible_rank(model, budget: "BudgetConfig",
                       param_factors: Optional[Dict[int, int]] = None) -> int:
    """
    Find the maximum TT rank that meets budget constraints.

    Searches downward from the model's current rank and returns the
    first (largest) rank whose estimated parameter count fits the budget.

    Args:
        model: Model to analyze.
        budget: Budget constraints.
        param_factors: Dict[rank → estimated_params]. When omitted, the
            parameter constraint cannot be evaluated and the current rank
            is returned.

    Returns:
        Maximum feasible rank (at least 1).
    """
    current_rank = 8  # default when the model exposes no config
    if hasattr(model, "config"):
        current_rank = model.config.tt_rank
    for rank in range(current_rank, 0, -1):
        # Explicit None checks: the original used truthiness, which
        # silently disabled the constraint when max_params was 0.
        if budget.max_params is not None and param_factors is not None:
            # Ranks missing from the table are treated as infeasible.
            est_params = param_factors.get(rank, float("inf"))
            if est_params > budget.max_params:
                continue
        return rank
    return 1