# Q-TensorFormer — src/budget.py (v3.0.0)
"""
Budget-constrained optimization.
Enforces deployment constraints during training and inference:
- Maximum parameter count
- Maximum inference latency
- Maximum energy per query
The model auto-adjusts tensor ranks to meet these constraints.
"""
import torch
import time
import math
from typing import Optional, Dict
from .config import BudgetConfig
class BudgetTracker:
    """
    Tracks whether a model meets deployment budget constraints.

    Checks at each validation step:
    - Parameter count ≤ max_params
    - Estimated latency ≤ max_latency_ms
    - Estimated energy ≤ max_energy_per_query
    """

    def __init__(self, budget: "BudgetConfig"):
        self.budget = budget

    def exceeds_budget(self, metrics: Dict, model_config) -> bool:
        """
        Check if current metrics exceed any budget constraint.

        Args:
            metrics: May contain "total_params", "latency_ms", "energy_uj".
                A missing key defaults to 0, i.e. never counts as a violation.
            model_config: Unused here; kept for interface compatibility.

        Returns:
            True if any configured constraint is violated.
        """
        if self.budget.max_params is not None:
            if metrics.get("total_params", 0) > self.budget.max_params:
                print(f"[BUDGET] Params exceeded: {metrics['total_params']} > {self.budget.max_params}")
                return True
        if self.budget.max_latency_ms is not None:
            if metrics.get("latency_ms", 0) > self.budget.max_latency_ms:
                print(f"[BUDGET] Latency exceeded: {metrics['latency_ms']:.2f} > {self.budget.max_latency_ms}")
                return True
        if self.budget.max_energy_per_query is not None:
            if metrics.get("energy_uj", 0) > self.budget.max_energy_per_query:
                print(f"[BUDGET] Energy exceeded: {metrics['energy_uj']:.2f} > {self.budget.max_energy_per_query}")
                return True
        return False

    def estimate_latency(self, model, seq_len: int = 128,
                         n_warmup: int = 3, n_measure: int = 10) -> float:
        """
        Estimate inference latency for a sequence of length seq_len.

        Fixes over the previous version:
        - synchronizes CUDA after warmup so queued warmup kernels are not
          charged to the first timed iteration;
        - uses time.perf_counter() (monotonic, high-resolution) instead of
          time.time();
        - restores the model's original train/eval mode afterwards instead of
          leaving it in eval mode.

        Returns:
            Mean forward-pass latency in milliseconds.
        """
        device = next(model.parameters()).device
        was_training = model.training
        model.eval()
        # Vocab bound 1000 is arbitrary; assumes the model accepts integer
        # token ids in [0, 1000) — TODO confirm against the real vocab size.
        dummy = torch.randint(0, 1000, (1, seq_len), device=device)
        latencies = []
        try:
            with torch.no_grad():
                # Warmup (not timed).
                for _ in range(n_warmup):
                    _ = model(dummy)
                if device.type == "cuda":
                    # Drain any still-queued warmup kernels before timing.
                    torch.cuda.synchronize()
                for _ in range(n_measure):
                    t0 = time.perf_counter()
                    _ = model(dummy)
                    if device.type == "cuda":
                        # Wait for the async kernel to actually finish.
                        torch.cuda.synchronize()
                    latencies.append((time.perf_counter() - t0) * 1000)
        finally:
            # Don't clobber the caller's training state.
            if was_training:
                model.train()
        return sum(latencies) / len(latencies)

    def estimate_parameter_budget(self, model, tt_rank: int) -> int:
        """
        Estimate total parameters at a given TT rank.

        Uses a rough quadratic scaling law (TT params ~ O(rank^2)) for the
        TT-factorized portion; all other parameters are assumed
        rank-independent. Falls back to the actual current count when the
        model exposes no TT bookkeeping.
        """
        current = sum(p.numel() for p in model.parameters())
        if not hasattr(model, "tt_params"):
            return current
        config = getattr(model, "config", None)
        if not config:
            return current
        current_rank = config.tt_rank
        # Rough scaling: rescale only the TT share of the parameters.
        # max(current_rank, 1) guards against division by zero.
        tt_now = model.tt_params
        tt_new = tt_now * (tt_rank / max(current_rank, 1)) ** 2
        return int(current - tt_now + tt_new)
class EnergyEstimator:
    """
    Estimates per-query energy consumption, using FLOP count as a proxy.

    Approximate conversions (hardware-dependent):
    - CPU inference: ~5 pJ/FLOP
    - GPU inference (A100): ~0.5 pJ/FLOP
    - Edge inference: ~10 pJ/FLOP
    """

    # Energy cost per FLOP, expressed in microjoules (μJ).
    ENERGY_PER_FLOP = {
        "cpu": 5e-6,        # 5 pJ → 5e-6 μJ
        "gpu_a100": 0.5e-6, # 0.5 pJ → 0.5e-6 μJ
        "edge": 10e-6,      # 10 pJ → 10e-6 μJ
    }

    def __init__(self, hardware: str = "cpu"):
        # Delegate to set_hardware so the lookup logic lives in one place.
        self.set_hardware(hardware)

    def estimate(self, model, batch_size: int = 1,
                 seq_len: int = 128) -> float:
        """
        Estimate energy consumption for one forward pass.

        Returns:
            Energy in microjoules (μJ).
        """
        return self._estimate_flops(model, batch_size, seq_len) * self.energy_per_flop

    @staticmethod
    def _estimate_flops(model, batch_size: int, seq_len: int) -> int:
        """Estimate FLOPs for one forward pass."""
        n_params = sum(p.numel() for p in model.parameters())
        # Rough multiply-add estimate: 2 FLOPs per parameter per token.
        return int(2 * n_params * batch_size * seq_len)

    def set_hardware(self, hardware: str):
        """Change hardware target (unknown targets fall back to the CPU figure)."""
        self.hardware = hardware
        self.energy_per_flop = self.ENERGY_PER_FLOP.get(hardware, 5e-6)
def find_feasible_rank(model, budget: "BudgetConfig",
                       param_factors: Optional[Dict[int, int]] = None) -> int:
    """
    Find the maximum TT rank that meets budget constraints.

    Scans downward from the model's current rank and returns the first
    (i.e. largest) rank whose estimated parameter count fits the budget.

    Args:
        model: Model to analyze; model.config.tt_rank is used when present.
        budget: Budget constraints.
        param_factors: Optional mapping rank -> estimated parameter count.
            When omitted, no parameter check is possible and the current
            rank is returned unchanged.

    Returns:
        Maximum feasible rank (at least 1).
    """
    current_rank = 8  # default when the model exposes no config
    if hasattr(model, "config"):
        current_rank = model.config.tt_rank
    for rank in range(current_rank, 0, -1):
        # Ranks absent from param_factors are treated as infeasible (inf).
        est_params = param_factors.get(rank, float("inf")) if param_factors else None
        # `is not None` rather than truthiness, so a max_params of 0 is
        # still enforced as a real constraint instead of being ignored.
        if (budget.max_params is not None and est_params is not None
                and est_params > budget.max_params):
            continue
        return rank
    # Every candidate rank exceeded the budget; fall back to the minimum.
    return 1