AshenNav committed on
Commit
744cb86
·
verified ·
1 Parent(s): a3e198c

Upload twill/cost_normalization.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. twill/cost_normalization.py +132 -0
twill/cost_normalization.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cost Normalization (Section 5.2)
3
+
4
+ Renders the ZLP and SMT problems tractable by finding smaller integer cycle counts
5
+ whose ratios closely approximate the original cycle count ratios.
6
+
7
+ The optimization problem:
8
+ Given original cycle counts C = [c1, c2, ..., cn],
9
+ find new counts C' = [c'1, c'2, ..., c'n] such that:
10
+
11
+ ∀i,j: -F ≤ C[i]·C'[j] - C[j]·C'[i] ≤ F (bounded ratio change)
12
+ 1 ≤ Σ C'[i] ≤ U (avoid zero solution)
13
+ minimize F (minimize distortion)
14
+
15
+ Uses PuLP (CBC solver) since SCIP is not freely available in Python.
16
+ The paper uses SCIP but notes CBC also works (just slower on some instances).
17
+ """
18
+
19
+ import pulp
20
+ import numpy as np
21
+ from typing import List, Dict, Tuple, Optional
22
+
23
+
24
def normalize_costs(
    original_costs: Dict[str, int],
    U: int = 300,
    solver_time_limit: int = 60,
) -> Tuple[Dict[str, int], float]:
    """Normalize cycle counts to keep ratios but reduce magnitudes.

    Solves an integer program (PuLP with the CBC solver) that picks new
    integer counts C' minimizing F subject to, for every pair i < j,
    ``|C[i]*C'[j] - C[j]*C'[i]| <= F`` and ``sum(C') <= U``, with each
    ``C'[i] >= 1``.

    Three cases are answered without invoking the solver:

    * empty input returns ``({}, 0.0)``;
    * all costs equal returns every label mapped to 1;
    * ``sum(costs) <= U`` returns a copy of the input unchanged.

    Args:
        original_costs: Dict mapping label -> original cycle count.
        U: Upper bound on the sum of the new costs (controls the
            resolution vs. speed tradeoff).
        solver_time_limit: Time limit in seconds passed to the solver.

    Returns:
        Tuple of (normalized_costs dict, distortion F value). When the
        solver does not report an optimal status, the costs come from a
        proportional-scaling fallback and the distortion is ``float('inf')``.

    Example:
        >>> costs = {"GEMM": 1000, "EXP": 1000, "TMA_LOAD": 500}
        >>> norm, F = normalize_costs(costs, U=10)
        >>> # norm might be {"GEMM": 2, "EXP": 2, "TMA_LOAD": 1}
    """
    labels = list(original_costs)
    costs = [original_costs[label] for label in labels]
    n = len(costs)

    if n == 0:
        return {}, 0.0

    # Trivial case: identical costs have ratio 1, so all ones is exact.
    if len(set(costs)) == 1:
        return {label: 1 for label in labels}, 0.0

    # Already within the budget: nothing to shrink.
    if sum(costs) <= U:
        return dict(original_costs), 0.0

    prob = pulp.LpProblem("CostNormalization", pulp.LpMinimize)

    # Decision variables: one new count C'[i] per label, at least 1.
    c_prime = [
        pulp.LpVariable(f"c_{i}", lowBound=1, cat="Integer") for i in range(n)
    ]

    # F bounds the cross-multiplied ratio error of every pair.
    F = pulp.LpVariable("F", lowBound=0, cat="Integer")

    # Objective: minimize the distortion bound.
    prob += F

    # 1 <= sum(C') <= U
    prob += pulp.lpSum(c_prime) >= 1
    prob += pulp.lpSum(c_prime) <= U

    # -F <= C[i]*C'[j] - C[j]*C'[i] <= F, written as two one-sided
    # constraints per unordered pair (the expression is antisymmetric,
    # so i < j covers every ordered pair).
    for i in range(n):
        for j in range(i + 1, n):
            prob += costs[i] * c_prime[j] - costs[j] * c_prime[i] <= F
            prob += costs[j] * c_prime[i] - costs[i] * c_prime[j] <= F

    solver = pulp.PULP_CBC_CMD(msg=0, timeLimit=solver_time_limit)
    status = prob.solve(solver)

    if status != pulp.constants.LpStatusOptimal:
        # Fallback: scale so the largest cost maps to about U / n, then
        # clamp every entry to at least 1.
        max_c = max(costs)
        scale = U / (n * max_c) if max_c > 0 else 1.0
        normalized = {
            label: max(1, int(round(c * scale)))
            for label, c in zip(labels, costs)
        }
        return normalized, float("inf")

    normalized = {}
    for i, label in enumerate(labels):
        # Solver values come back as floats; round to the nearest integer.
        val = int(round(pulp.value(c_prime[i])))
        normalized[label] = max(1, val)

    distortion = pulp.value(F)
    return normalized, distortion
103
+
104
+
105
def apply_normalization(
    original_costs: Dict[str, int],
    normalized_costs: Dict[str, int],
    edge_delays: List[Tuple[str, str, int]],
) -> List[Tuple[str, str, int]]:
    """Scale edge delays according to normalized instruction costs.

    Each delay is multiplied by the ratio
    ``normalized_costs[src] / original_costs[src]`` of its source
    instruction, rounded to the nearest integer, and clamped to a minimum
    of 1 (so a rescaled delay is never 0, even if the original delay was).

    An edge is passed through with its delay unchanged when its source has
    no entry in ``original_costs``, has a non-positive original cost, or has
    no entry in ``normalized_costs``.

    Args:
        original_costs: Original cycle counts per instruction.
        normalized_costs: Normalized cycle counts per instruction.
        edge_delays: List of (src, dst, original_delay) tuples.

    Returns:
        List of (src, dst, normalized_delay) tuples, in input order.
    """
    result = []
    for src, dst, delay in edge_delays:
        original = original_costs.get(src, 0)
        if original > 0 and src in normalized_costs:
            # Scale the delay by the source instruction's cost ratio.
            ratio = normalized_costs[src] / original
            new_delay = max(1, int(round(delay * ratio)))
        else:
            # No usable ratio for this source: keep the delay as given.
            new_delay = delay
        result.append((src, dst, new_delay))
    return result