"""
Budget-constrained optimization.

Enforces deployment constraints during training and inference:
- Maximum parameter count
- Maximum inference latency
- Maximum energy per query

The model auto-adjusts tensor ranks to meet these constraints.
"""

import math
import time
from typing import Dict, Optional

import torch

from .config import BudgetConfig


class BudgetTracker:
    """
    Tracks whether a model meets deployment budget constraints.

    Checks at each validation step:
    - Parameter count ≤ max_params
    - Estimated latency ≤ max_latency_ms
    - Estimated energy ≤ max_energy_per_query
    """

    def __init__(self, budget: BudgetConfig):
        """Store the budget whose limits are checked by this tracker."""
        self.budget = budget

    def exceeds_budget(self, metrics: Dict, model_config) -> bool:
        """
        Check if current metrics exceed any budget constraint.

        Constraints are checked in the order params, latency, energy; the
        first violation is printed and the check stops there.

        Args:
            metrics: Mapping read for the keys "total_params", "latency_ms"
                and "energy_uj". A missing key is treated as 0.
            model_config: Not used by this method; kept for interface
                compatibility.

        Returns:
            True if any constraint is violated, False otherwise. Limits set
            to None are skipped.
        """
        max_params = self.budget.max_params
        if max_params is not None:
            # Read the value once so the message can never hit a missing key.
            total_params = metrics.get("total_params", 0)
            if total_params > max_params:
                print(f"[BUDGET] Params exceeded: {total_params} > {max_params}")
                return True

        max_latency_ms = self.budget.max_latency_ms
        if max_latency_ms is not None:
            latency_ms = metrics.get("latency_ms", 0)
            if latency_ms > max_latency_ms:
                print(f"[BUDGET] Latency exceeded: {latency_ms:.2f} > {max_latency_ms}")
                return True

        max_energy = self.budget.max_energy_per_query
        if max_energy is not None:
            energy_uj = metrics.get("energy_uj", 0)
            if energy_uj > max_energy:
                print(f"[BUDGET] Energy exceeded: {energy_uj:.2f} > {max_energy}")
                return True

        return False

    def estimate_latency(
        self,
        model,
        seq_len: int = 128,
        n_warmup: int = 3,
        n_measure: int = 10,
        vocab_size: int = 1000,
    ) -> float:
        """
        Estimate inference latency for a sequence of length seq_len.

        Runs the model on a random integer batch of shape (1, seq_len) under
        torch.no_grad(), in eval mode. The model's previous training mode is
        restored before returning.

        Args:
            model: Module to time; its first parameter decides the device.
            seq_len: Length of the dummy input sequence.
            n_warmup: Number of untimed forward passes run first.
            n_measure: Number of timed forward passes (must be >= 1).
            vocab_size: Exclusive upper bound for the random token ids.

        Returns:
            Mean latency in milliseconds.

        Raises:
            ValueError: If n_measure is smaller than 1.
        """
        if n_measure < 1:
            raise ValueError("n_measure must be at least 1")

        device = next(model.parameters()).device
        on_cuda = device.type == "cuda"

        # Remember the mode so a latency probe during training does not
        # silently leave the model in eval mode.
        was_training = getattr(model, "training", False)
        model.eval()
        dummy = torch.randint(0, vocab_size, (1, seq_len)).to(device)

        latencies = []
        try:
            with torch.no_grad():
                for _ in range(n_warmup):
                    model(dummy)

                for _ in range(n_measure):
                    if on_cuda:
                        # Drain queued kernels (e.g. from warmup) so they are
                        # not charged to this measurement.
                        torch.cuda.synchronize(device)
                    t0 = time.perf_counter()
                    model(dummy)
                    if on_cuda:
                        # CUDA launches are asynchronous; wait for completion
                        # before reading the clock.
                        torch.cuda.synchronize(device)
                    latencies.append((time.perf_counter() - t0) * 1000)
        finally:
            if was_training:
                model.train()

        return sum(latencies) / len(latencies)

    def estimate_parameter_budget(self, model, tt_rank: int) -> int:
        """
        Estimate total parameters at a given TT rank.

        Uses the approximation that TT parameters scale ~ O(rank^2): the
        model's `tt_params` count is rescaled by (tt_rank / current_rank)^2
        and the remaining parameters are left unchanged.

        Args:
            model: Model to analyze. If it has no `tt_params` attribute or no
                truthy `config`, the current parameter count is returned.
            tt_rank: Target TT rank.

        Returns:
            Estimated total parameter count, truncated to an int.
        """
        current = sum(p.numel() for p in model.parameters())
        if not hasattr(model, "tt_params"):
            return current

        config = getattr(model, "config", None)
        if not config:
            return current

        # Rough scaling; max(..., 1) avoids dividing by a zero rank.
        current_rank = config.tt_rank
        tt_now = model.tt_params
        tt_new = tt_now * (tt_rank / max(current_rank, 1)) ** 2
        return int(current - tt_now + tt_new)


class EnergyEstimator:
    """
    Energy consumption estimator using FLOPs as proxy.

    Approximate conversions (hardware-dependent):
    - CPU inference: ~5 pJ/FLOP
    - GPU inference (A100): ~0.5 pJ/FLOP
    - Edge inference: ~10 pJ/FLOP
    """

    # Energy per FLOP in microjoules (μJ)
    ENERGY_PER_FLOP = {
        "cpu": 5e-6,         # 5 pJ → 5e-6 μJ
        "gpu_a100": 0.5e-6,  # 0.5 pJ → 0.5e-6 μJ
        "edge": 10e-6,       # 10 pJ → 10e-6 μJ
    }

    # Value used when the hardware name is not in ENERGY_PER_FLOP
    # (same as the "cpu" entry).
    _DEFAULT_ENERGY_PER_FLOP = 5e-6

    def __init__(self, hardware: str = "cpu"):
        """Create an estimator for the given hardware target."""
        self.set_hardware(hardware)

    def estimate(self, model, batch_size: int = 1, seq_len: int = 128) -> float:
        """
        Estimate energy consumption in μJ for one forward pass.

        Args:
            model: Model whose parameters are counted.
            batch_size: Number of sequences in the batch.
            seq_len: Tokens per sequence.

        Returns:
            Energy in microjoules.
        """
        flops = self._estimate_flops(model, batch_size, seq_len)
        return flops * self.energy_per_flop

    @staticmethod
    def _estimate_flops(model, batch_size: int, seq_len: int) -> int:
        """Estimate FLOPs for one forward pass."""
        total_params = sum(p.numel() for p in model.parameters())
        # Rough: 2 × params × batch × seq_len (multiply-add for each token)
        return int(2 * total_params * batch_size * seq_len)

    def set_hardware(self, hardware: str):
        """
        Change hardware target.

        Unknown hardware names fall back to the CPU cost of 5e-6 μJ/FLOP.
        """
        self.hardware = hardware
        self.energy_per_flop = self.ENERGY_PER_FLOP.get(
            hardware, self._DEFAULT_ENERGY_PER_FLOP
        )


def find_feasible_rank(
    model,
    budget: BudgetConfig,
    param_factors: Optional[Dict[int, int]] = None,
) -> int:
    """
    Find the maximum TT rank that meets budget constraints.

    Only the parameter budget (`budget.max_params`) is checked here. Ranks
    are tried from the model's current rank downwards; a rank missing from
    `param_factors` is treated as infeasible.

    Args:
        model: Model to analyze. Its `config.tt_rank` is the starting rank;
            8 is used when the model has no `config` attribute.
        budget: Budget constraints.
        param_factors: Dict[rank → estimated_params]. When empty or None, no
            rank is rejected.

    Returns:
        Maximum feasible rank, or 1 if no rank in the searched range fits.
    """
    current_rank = 8  # default
    if hasattr(model, "config"):
        current_rank = model.config.tt_rank

    max_params = budget.max_params
    # `is not None` (rather than truthiness) so a limit of 0 is still a limit,
    # matching how BudgetTracker interprets the same field.
    constrained = max_params is not None and bool(param_factors)

    for rank in range(current_rank, 0, -1):
        if constrained and param_factors.get(rank, float("inf")) > max_params:
            continue
        return rank
    return 1