AshenNav commited on
Commit
b84ba2b
·
verified ·
1 Parent(s): 744cb86

Upload twill/modulo_scheduler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. twill/modulo_scheduler.py +235 -0
twill/modulo_scheduler.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase 1: Optimal Modulo Scheduling via Integer Linear Programming (ILP).
3
+
4
+ Based on Section 3.1, 4.1, and 5.1 of the paper.
5
+ Uses the ILP formulation from Stoutchinin et al. (referenced as [stoutchinin-ilp]).
6
+
7
+ The modulo scheduling problem:
8
+ Given G = (V, E) and target initiation interval I,
9
+ find M: V -> [0, L) such that:
10
+ 1. Dependence: ∀(u,v,d,δ)∈E: M(v) - M(u) + I·δ ≥ d
11
+ 2. Resource: modular RRT fits within machine capacities
12
+ 3. Minimize L (schedule length) subject to the above
13
+
14
+ Uses CBC solver via PuLP.
15
+ """
16
+
17
+ import pulp
18
+ import numpy as np
19
+ from typing import Dict, List, Optional, Tuple
20
+ from twill.graph import DependenceGraph, Instruction, DependenceEdge
21
+
22
+
23
class ModuloScheduleResult:
    """Result of modulo scheduling.

    Attributes:
        schedule: Dict mapping instruction name -> clock cycle M(v).
        initiation_interval: I, the initiation interval of the schedule.
        length: L, the total schedule length (see the ``length`` property).
        num_copies: ceil(L / I), the number of overlapping iterations.
    """

    def __init__(self, schedule: Dict[str, int], I: int):
        self.schedule = schedule
        self.initiation_interval = I
        # Explicit schedule length; stays None until assigned via ``length``.
        self._length = None

    @property
    def I(self) -> int:
        """Short alias for ``initiation_interval``."""
        return self.initiation_interval

    @property
    def length(self) -> int:
        """L: total schedule length.

        Returns the value assigned through the setter when there is one.
        Otherwise falls back to ``max(M(v)) + 1``. The fallback has no access
        to per-instruction cycle counts, so it only accounts for a single
        cycle after the latest start. An empty schedule has length 0.
        """
        if self._length is not None:
            return self._length
        # default=-1 makes an empty schedule yield 0 instead of raising.
        return max(self.schedule.values(), default=-1) + 1

    @length.setter
    def length(self, val: int):
        self._length = val

    @property
    def num_copies(self) -> int:
        """ceil(L / I): number of overlapping iteration copies."""
        # Exact integer ceiling division; avoids float rounding for large L.
        return int(-(-self.length // self.I))

    def __repr__(self):
        return (f"ModuloSchedule(I={self.I}, L={self.length}, copies={self.num_copies}, "
                f"schedule={self.schedule})")
61
+
62
+
63
def optimal_modulo_schedule(
    graph: DependenceGraph,
    target_I: int,
    solver_time_limit: int = 120,
    verbose: bool = False,
) -> Optional[ModuloScheduleResult]:
    """Find a minimum-length modulo schedule for the given initiation interval.

    Builds an ILP and solves it with CBC through PuLP. The model minimizes the
    schedule length L subject to:

      * Dependence: for every edge (u, v, d, delta),
        M(v) - M(u) >= d - I * delta.
      * Resources: for every slot t in [0, I) and functional unit f, the
        summed usage of all instructions occupying f at a cycle congruent to
        t (mod I) does not exceed the capacity of f.

    Args:
        graph: The loop dependence graph. This function reads ``graph.V``
            (instructions with ``name``, ``cycles`` and ``rrt``), ``graph.E``
            (edges with ``src``, ``dst``, ``delay`` and ``iteration_delay``)
            and ``graph.machine``.
        target_I: Target initiation interval.
        solver_time_limit: Time limit for the solver in seconds.
        verbose: Print solver output.

    Returns:
        A ModuloScheduleResult when the solver reports an optimal solution.
        None when ``target_I`` is smaller than 1 or when the solver reports
        any status other than optimal (for example an infeasible model).
    """
    I = target_I
    if I < 1:
        # No slot exists for I < 1, so no modulo schedule can exist.
        return None

    V = graph.V
    E = graph.E
    machine = graph.machine
    n = len(V)

    prob = pulp.LpProblem(f"ModuloSchedule_I{I}", pulp.LpMinimize)

    # Heuristic upper bound on every start cycle and on L. Edge delays are
    # included because a latency may exceed the number of cycles an
    # instruction occupies its units; bounding by cycles alone can cut off
    # feasible schedules. The trailing 1 keeps the bound positive.
    max_cycles = max((v.cycles for v in V), default=1)
    max_delay = max((e.delay for e in E), default=0)
    big_M = I * (n + 1) * max(max_cycles, max_delay, 1)

    # Decision variables: M[v] is the clock cycle instruction v starts at.
    M = {
        v.name: pulp.LpVariable(
            f"M_{v.name}", lowBound=0, upBound=big_M, cat='Integer'
        )
        for v in V
    }

    # Auxiliary variable for the schedule length L = max(M(v) + cycles(v)).
    L = pulp.LpVariable("L", lowBound=1, upBound=big_M, cat='Integer')

    # Objective: minimize L.
    prob += L

    # L must cover the last occupied cycle of every instruction.
    for v in V:
        prob += L >= M[v.name] + v.cycles

    # Dependence constraints: M(v) + I*delta >= M(u) + d.
    for e in E:
        prob += M[e.dst] - M[e.src] >= e.delay - I * e.iteration_delay

    # M(v) mod I is not linear in M(v). Linearize it with one binary
    # indicator per (instruction, slot): slot[v, s] == 1 iff M(v) mod I == s.
    slot = {}
    for v in V:
        for s in range(I):
            slot[v.name, s] = pulp.LpVariable(f"slot_{v.name}_{s}", cat='Binary')
        # Exactly one slot per instruction.
        prob += pulp.lpSum(slot[v.name, s] for s in range(I)) == 1
        # Link M(v) to its slot: M(v) = q * I + s for some integer q >= 0.
        q_v = pulp.LpVariable(
            f"q_{v.name}", lowBound=0, upBound=big_M // I, cat='Integer'
        )
        prob += M[v.name] == q_v * I + pulp.lpSum(
            s * slot[v.name, s] for s in range(I)
        )

    # Modular resource constraints. An instruction starting in slot s uses
    # unit f during its cycle c at slot (s + c) mod I, so slot t is reached
    # from start slot (t - c) mod I.
    for t in range(I):
        for f_idx, f_name in enumerate(machine.functional_units):
            cap = machine.capacity(f_name)

            terms = []
            for v in V:
                for c in range(v.cycles):
                    usage = int(v.rrt[c, f_idx])
                    if usage > 0:
                        terms.append(usage * slot[v.name, (t - c) % I])

            # The constraint is emitted even when cap <= 0: a unit without
            # capacity that some instruction needs makes the model
            # infeasible, which matches what validate_schedule reports.
            if terms:
                prob += pulp.lpSum(terms) <= cap

    solver = pulp.PULP_CBC_CMD(msg=1 if verbose else 0, timeLimit=solver_time_limit)
    status = prob.solve(solver)

    if status != pulp.constants.LpStatusOptimal:
        return None

    # Solver values come back as floats; round before converting to int.
    schedule = {v.name: int(round(pulp.value(M[v.name]))) for v in V}

    result = ModuloScheduleResult(schedule, I)
    result.length = int(round(pulp.value(L)))

    return result
172
+
173
+
174
def compute_modular_rrt(
    graph: DependenceGraph,
    schedule: ModuloScheduleResult,
) -> np.ndarray:
    """Compute the modular resource reservation table (RRT) of a schedule.

    Entry ``[t, f]`` of the result is the total usage of functional unit f at
    time slot t in steady state: the usage of every instruction v during its
    cycle c is added to slot ``(M(v) + c) mod I``.

    Args:
        graph: Dependence graph. This function reads ``graph.V`` (instructions
            with ``name``, ``cycles`` and ``rrt``) and
            ``graph.machine.num_functional_units``.
        schedule: Schedule providing ``I`` and the ``schedule`` mapping from
            instruction name to start cycle.

    Returns:
        np.ndarray of shape (I, num_functional_units) with integer dtype.

    Raises:
        ValueError: If the initiation interval is smaller than 1.
        KeyError: If an instruction of the graph has no entry in the schedule.
    """
    I = schedule.I
    if I < 1:
        raise ValueError(f"Initiation interval must be >= 1, got {I}")

    num_fus = graph.machine.num_functional_units
    mod_rrt = np.zeros((I, num_fus), dtype=int)
    # Explicit column indices: unlike a slice, indexing with them raises
    # IndexError when an instruction's RRT has fewer columns than the machine
    # has functional units.
    fu_columns = np.arange(num_fus)

    for v in graph.V:
        m_v = schedule.schedule[v.name]
        rrt = np.asarray(v.rrt)
        for c in range(v.cycles):
            # Add the whole row of cycle c to its modular slot at once.
            mod_rrt[(m_v + c) % I] += rrt[c, fu_columns].astype(int)

    return mod_rrt
198
+
199
+
200
def validate_schedule(
    graph: DependenceGraph,
    schedule: ModuloScheduleResult,
) -> Tuple[bool, List[str]]:
    """Check a modulo schedule against dependence and resource constraints.

    Dependence: every edge of ``graph.E`` must satisfy
    ``M(dst) + I * iteration_delay >= M(src) + delay``.
    Resources: every entry of the modular RRT computed by
    ``compute_modular_rrt`` must not exceed the matching entry of
    ``graph.machine.capacity_vector``.

    Returns:
        Tuple ``(is_valid, messages)`` with one message per violated
        constraint; ``is_valid`` is True exactly when no message was produced.
    """
    interval = schedule.I
    start = schedule.schedule
    problems = []

    # Dependence constraints, one check per edge.
    for edge in graph.E:
        src_cycle = start[edge.src]
        dst_cycle = start[edge.dst]
        available = dst_cycle + interval * edge.iteration_delay
        required = src_cycle + edge.delay
        if available < required:
            problems.append(
                f"Dependence violation: {edge.src}({src_cycle}) -> {edge.dst}({dst_cycle}), "
                f"need M({edge.dst}) + I*{edge.iteration_delay} >= M({edge.src}) + {edge.delay}, "
                f"got {available} < {required}"
            )

    # Resource constraints: compare steady-state usage with capacity for
    # every (time slot, functional unit) pair.
    usage_table = compute_modular_rrt(graph, schedule)
    capacities = graph.machine.capacity_vector
    for t in range(interval):
        for f in range(graph.machine.num_functional_units):
            used = usage_table[t, f]
            limit = capacities[f]
            if used > limit:
                problems.append(
                    f"Resource violation at t={t}, {graph.machine.functional_units[f]}: "
                    f"usage={used} > capacity={limit}"
                )

    return not problems, problems