| """ |
| Budget-constrained optimization. |
| |
| Enforces deployment constraints during training and inference: |
| - Maximum parameter count |
| - Maximum inference latency |
| - Maximum energy per query |
| |
| The model auto-adjusts tensor ranks to meet these constraints. |
| """ |
|
|
| import torch |
| import time |
| import math |
| from typing import Optional, Dict |
| from .config import BudgetConfig |
|
|
|
|
class BudgetTracker:
    """
    Tracks whether a model meets deployment budget constraints.

    Checks at each validation step:
    - Parameter count ≤ max_params
    - Estimated latency ≤ max_latency_ms
    - Estimated energy ≤ max_energy_per_query
    """

    def __init__(self, budget: "BudgetConfig"):
        # Budget constraints; any field set to None means "unconstrained".
        self.budget = budget

    def exceeds_budget(self, metrics: Dict, model_config) -> bool:
        """
        Check if current metrics exceed any budget constraint.

        Args:
            metrics: Measured values; recognized keys are "total_params",
                "latency_ms" and "energy_uj". A missing key defaults to 0,
                i.e. that constraint is treated as satisfied.
            model_config: Unused; kept for interface compatibility.

        Returns:
            True if any constraint is violated.
        """
        if self.budget.max_params is not None:
            if metrics.get("total_params", 0) > self.budget.max_params:
                print(f"[BUDGET] Params exceeded: {metrics['total_params']} > {self.budget.max_params}")
                return True

        if self.budget.max_latency_ms is not None:
            if metrics.get("latency_ms", 0) > self.budget.max_latency_ms:
                print(f"[BUDGET] Latency exceeded: {metrics['latency_ms']:.2f} > {self.budget.max_latency_ms}")
                return True

        if self.budget.max_energy_per_query is not None:
            if metrics.get("energy_uj", 0) > self.budget.max_energy_per_query:
                print(f"[BUDGET] Energy exceeded: {metrics['energy_uj']:.2f} > {self.budget.max_energy_per_query}")
                return True

        return False

    def estimate_latency(self, model, seq_len: int = 128,
                         n_warmup: int = 3, n_measure: int = 10,
                         vocab_size: int = 1000) -> float:
        """
        Estimate inference latency for a sequence of length seq_len.

        Args:
            model: Model to time; must accept a (1, seq_len) LongTensor of
                token ids. Switched to eval mode as a side effect.
            seq_len: Sequence length of the dummy input.
            n_warmup: Untimed warm-up forward passes.
            n_measure: Timed forward passes to average over.
            vocab_size: Exclusive upper bound for the dummy token ids; must
                not exceed the model's vocabulary size.

        Returns:
            Mean latency in milliseconds.
        """
        device = next(model.parameters()).device
        model.eval()

        dummy = torch.randint(0, vocab_size, (1, seq_len)).to(device)

        with torch.no_grad():
            for _ in range(n_warmup):
                _ = model(dummy)
            # Bug fix: drain asynchronous warm-up kernels before timing,
            # otherwise the first measured pass absorbs their cost.
            if device.type == "cuda":
                torch.cuda.synchronize()

            latencies = []
            for _ in range(n_measure):
                # perf_counter is monotonic and higher-resolution than
                # time.time(), which can jump on clock adjustments.
                t0 = time.perf_counter()
                _ = model(dummy)
                if device.type == "cuda":
                    torch.cuda.synchronize()
                latencies.append((time.perf_counter() - t0) * 1000)

        return sum(latencies) / len(latencies)

    def estimate_parameter_budget(self, model, tt_rank: int) -> int:
        """
        Estimate the total parameter count if the TT rank were set to tt_rank.

        TT-core parameter count scales roughly quadratically with rank, so
        the tensor-train share is rescaled by (tt_rank / current_rank) ** 2.
        Models without TT metadata just report their current size.
        """
        current = sum(p.numel() for p in model.parameters())
        if hasattr(model, "tt_params"):
            config = getattr(model, "config", None)
            if not config:
                # No config to read the current rank from — cannot rescale.
                return current
            current_rank = config.tt_rank

            tt_now = model.tt_params
            # max(..., 1) guards against a zero current rank.
            tt_new = tt_now * (tt_rank / max(current_rank, 1)) ** 2
            return int(current - tt_now + tt_new)
        return current
|
|
|
|
class EnergyEstimator:
    """
    Energy consumption estimator using FLOPs as proxy.

    Approximate conversions (hardware-dependent):
    - CPU inference: ~5 pJ/FLOP
    - GPU inference (A100): ~0.5 pJ/FLOP
    - Edge inference: ~10 pJ/FLOP
    """

    # μJ consumed per FLOP for each supported hardware target.
    ENERGY_PER_FLOP = {
        "cpu": 5e-6,
        "gpu_a100": 0.5e-6,
        "edge": 10e-6,
    }

    def __init__(self, hardware: str = "cpu"):
        # Delegate to set_hardware so construction and later retargeting
        # share a single code path.
        self.set_hardware(hardware)

    def estimate(self, model, batch_size: int = 1,
                 seq_len: int = 128) -> float:
        """
        Estimate energy consumption in μJ for one forward pass.

        Returns:
            Energy in microjoules.
        """
        flop_count = self._estimate_flops(model, batch_size, seq_len)
        return flop_count * self.energy_per_flop

    @staticmethod
    def _estimate_flops(model, batch_size: int, seq_len: int) -> int:
        """Estimate FLOPs for one forward pass (2 FLOPs per weight per token)."""
        n_weights = sum(p.numel() for p in model.parameters())
        return int(2 * n_weights * batch_size * seq_len)

    def set_hardware(self, hardware: str):
        """Change hardware target; unknown targets fall back to the CPU rate."""
        self.hardware = hardware
        self.energy_per_flop = self.ENERGY_PER_FLOP.get(hardware, 5e-6)
|
|
|
|
def find_feasible_rank(model, budget: "BudgetConfig",
                       param_factors: Optional[Dict[int, int]] = None) -> int:
    """
    Find the maximum TT rank that meets the parameter budget.

    Searches downward from the model's current rank and returns the first
    rank whose estimated parameter count fits within budget.max_params.
    Latency/energy constraints are not consulted here.

    Args:
        model: Model to analyze; ``model.config.tt_rank`` (if present) is the
            starting rank, otherwise 8 is assumed.
        budget: Budget constraints; only ``max_params`` is used.
        param_factors: Mapping from rank to estimated parameter count.
            Ranks missing from the mapping are treated as infeasible.
            If None, no estimate is available and the starting rank is
            returned unchanged.

    Returns:
        Maximum feasible rank (at least 1).
    """
    current_rank = 8
    if hasattr(model, "config"):
        current_rank = model.config.tt_rank

    for rank in range(current_rank, 0, -1):
        # Explicit None checks (matching BudgetTracker.exceeds_budget) so a
        # legitimate max_params == 0 budget is still enforced.
        if param_factors is not None and budget.max_params is not None:
            # Unknown ranks are assumed to blow the budget.
            est_params = param_factors.get(rank, float("inf"))
            if est_params > budget.max_params:
                continue
        return rank
    return 1
|
|