# twill/cost_normalization.py
# Part of the twill-swp-ws project; uploaded via huggingface_hub
# (commit 744cb86, verified).
"""
Cost Normalization (Section 5.2)
Renders the ZLP and SMT problems tractable by finding smaller integer cycle counts
whose ratios closely approximate the original cycle count ratios.
The optimization problem:
Given original cycle counts C = [c1, c2, ..., cn],
find new counts C' = [c'1, c'2, ..., c'n] such that:
∀i,j: -F ≤ C[i]·C'[j] - C[j]·C'[i] ≤ F (bounded ratio change)
1 ≤ Σ C'[i] ≤ U (avoid zero solution)
minimize F (minimize distortion)
Uses PuLP (CBC solver) since SCIP is not freely available in Python.
The paper uses SCIP but notes CBC also works (just slower on some instances).
"""
import math
from typing import Dict, List, Optional, Tuple

import numpy as np
import pulp
def normalize_costs(
    original_costs: Dict[str, int],
    U: int = 300,
    solver_time_limit: int = 60,
) -> Tuple[Dict[str, int], float]:
    """Normalize cycle counts to keep ratios but reduce magnitudes.

    Implements the Section 5.2 optimization: find integer counts C' with
    1 <= sum(C') <= U that minimize the pairwise ratio distortion
    F >= |C[i]*C'[j] - C[j]*C'[i]|.  Cheap exact paths (all-equal, already
    small, GCD-reducible) are handled without invoking the ILP solver.

    Args:
        original_costs: Dict mapping label -> original cycle count.
        U: Upper bound on sum of new costs (resolution vs. speed tradeoff).
        solver_time_limit: Time limit in seconds for the CBC solver.

    Returns:
        Tuple of (normalized_costs dict, distortion F).  F is 0.0 when all
        ratios are preserved exactly, and float('inf') when the solver failed
        and a proportional-scaling fallback was used.

    Example:
        >>> costs = {"GEMM": 1000, "EXP": 1000, "TMA_LOAD": 500}
        >>> norm, F = normalize_costs(costs, U=10)
        >>> # norm == {"GEMM": 2, "EXP": 2, "TMA_LOAD": 1}, F == 0.0
    """
    labels = list(original_costs.keys())
    C = [original_costs[l] for l in labels]
    n = len(C)
    if n == 0:
        return {}, 0.0
    # Trivial case: all costs equal -> every ratio is 1, so all-1s is exact.
    if len(set(C)) == 1:
        return {l: 1 for l in labels}, 0.0
    # Already small enough: keep the original counts with zero distortion.
    if sum(C) <= U:
        return dict(original_costs), 0.0
    # Exact reduction: dividing every count by the common GCD preserves all
    # ratios exactly (F = 0).  Use it whenever the reduced sum fits the
    # budget -- this solves cases like {1000, 1000, 500} -> {2, 2, 1}
    # without running the ILP at all.
    g = C[0]
    for c in C[1:]:
        g = math.gcd(g, c)
    if g > 1 and sum(C) // g <= U:
        return {l: c // g for l, c in zip(labels, C)}, 0.0
    # Set up the ILP problem.
    prob = pulp.LpProblem("CostNormalization", pulp.LpMinimize)
    # Decision variables: C'[i] >= 1 for each instruction.
    c_prime = [pulp.LpVariable(f"c_{i}", lowBound=1, cat='Integer') for i in range(n)]
    # F: the distortion bound (the objective).
    F = pulp.LpVariable("F", lowBound=0, cat='Integer')
    prob += F
    # Constraint: 1 <= sum(C') <= U (lower bound is implied by lowBound=1,
    # kept explicit to mirror the paper's formulation).
    prob += pulp.lpSum(c_prime) >= 1
    prob += pulp.lpSum(c_prime) <= U
    # Ratio constraints for all pairs: -F <= C[i]*C'[j] - C[j]*C'[i] <= F.
    for i in range(n):
        for j in range(i + 1, n):
            prob += C[i] * c_prime[j] - C[j] * c_prime[i] <= F
            prob += C[j] * c_prime[i] - C[i] * c_prime[j] <= F
    # Solve with CBC (the paper uses SCIP; CBC works but can be slower).
    solver = pulp.PULP_CBC_CMD(msg=0, timeLimit=solver_time_limit)
    status = prob.solve(solver)
    if status != pulp.constants.LpStatusOptimal:
        # Fallback: proportional scaling.  Note the max(1, ...) floor can
        # push the sum slightly above U; distortion is reported as inf so
        # callers can detect the solver failure.
        max_c = max(C)
        scale = U / (n * max_c) if max_c > 0 else 1.0
        normalized = {l: max(1, int(round(c * scale))) for l, c in zip(labels, C)}
        return normalized, float('inf')
    normalized = {}
    for i, label in enumerate(labels):
        val = int(round(pulp.value(c_prime[i])))
        normalized[label] = max(1, val)
    # pulp.value can return None in degenerate cases; treat that as zero
    # distortion rather than propagating None to callers.
    distortion = pulp.value(F)
    return normalized, float(distortion) if distortion is not None else 0.0
def apply_normalization(
    original_costs: Dict[str, int],
    normalized_costs: Dict[str, int],
    edge_delays: List[Tuple[str, str, int]],
) -> List[Tuple[str, str, int]]:
    """Scale edge delays according to normalized instruction costs.

    Edge delays track the execution time of the edge's source instruction,
    so when instruction costs are normalized the delays must shrink/grow by
    the same factor.

    Args:
        original_costs: Original cycle counts per instruction.
        normalized_costs: Normalized cycle counts per instruction.
        edge_delays: List of (src, dst, original_delay) tuples.

    Returns:
        List of (src, dst, normalized_delay) tuples.  Delays whose source
        is unknown or has non-positive original cost pass through unchanged;
        rescaled delays are floored at 1.
    """

    def _rescale(src: str, delay: int) -> int:
        old_cost = original_costs.get(src, 0)
        if old_cost <= 0:
            # Unknown or zero-cost source: nothing to scale by.
            return delay
        factor = normalized_costs[src] / old_cost
        return max(1, int(round(delay * factor)))

    return [(src, dst, _rescale(src, delay)) for src, dst, delay in edge_delays]