AshenNav commited on
Commit
62cd880
·
verified ·
1 Parent(s): b84ba2b

Upload twill/smt_joint.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. twill/smt_joint.py +426 -0
twill/smt_joint.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase 2: Joint SWP + WS via SMT Solver (Z3).
3
+
4
+ Based on Section 4 of the paper (Figures 4, 5, and 6).
5
+
6
+ Takes the initial modulo schedule M and initiation interval I from Phase 1,
7
+ then formulates constraints over a straight-line program Q derived from M.
8
+
9
+ Constraint groups:
10
+ 1. Modulo Scheduling Constraints (Figure 4):
11
+ - Uniqueness, Consistency, Completion, Dependence, Capacity
12
+ 2. Memory Allocation Constraints (Figure 5):
13
+ - Memory Capacity, Init, LiveProp-1, LiveProp-2, DeadProp-1, DeadProp-2
14
+ 3. Warp Assignment Constraints (Figure 6):
15
+ - WarpUniqueness, VariableLatency, WarpCapacity, WarpSync
16
+
17
+ Uses Z3 SMT solver (QFLIA theory - quantifier-free linear integer arithmetic).
18
+ """
19
+
20
+ import z3
21
+ import numpy as np
22
+ from typing import Dict, List, Optional, Tuple, Set
23
+ from twill.graph import DependenceGraph, Instruction, DependenceEdge, MachineDescription
24
+ from twill.modulo_scheduler import ModuloScheduleResult
25
+
26
+
27
+ class WarpAssignment:
28
+ """Result of warp assignment.
29
+
30
+ Attributes:
31
+ assignment: Dict mapping instruction name -> warp index
32
+ warp_names: Optional names for warps (e.g., "producer", "consumer_A", "consumer_B")
33
+ """
34
+ def __init__(self, assignment: Dict[str, int], num_warps: int):
35
+ self.assignment = assignment
36
+ self.num_warps = num_warps
37
+ self.warp_names: Dict[int, str] = {}
38
+
39
+ def warp_of(self, instr_name: str) -> int:
40
+ return self.assignment[instr_name]
41
+
42
+ def instructions_on_warp(self, warp_idx: int) -> List[str]:
43
+ return [name for name, w in self.assignment.items() if w == warp_idx]
44
+
45
+ def __repr__(self):
46
+ warp_groups = {}
47
+ for name, w in self.assignment.items():
48
+ warp_groups.setdefault(w, []).append(name)
49
+ lines = []
50
+ for w in sorted(warp_groups):
51
+ label = self.warp_names.get(w, f"Warp {w}")
52
+ lines.append(f" {label}: {warp_groups[w]}")
53
+ return "WarpAssignment(\n" + "\n".join(lines) + "\n)"
54
+
55
+
56
+ class JointSWPWSResult:
57
+ """Result of the joint SWP + WS optimization.
58
+
59
+ Attributes:
60
+ schedule: Dict mapping instruction name -> clock cycle in Q*
61
+ initiation_interval: I
62
+ length: L (schedule length)
63
+ warp_assignment: WarpAssignment
64
+ num_copies: ceil(L/I)
65
+ """
66
+ def __init__(
67
+ self,
68
+ schedule: Dict[str, int],
69
+ I: int,
70
+ L: int,
71
+ warp_assignment: WarpAssignment,
72
+ ):
73
+ self.schedule = schedule
74
+ self.initiation_interval = I
75
+ self.length = L
76
+ self.warp_assignment = warp_assignment
77
+
78
+ @property
79
+ def I(self) -> int:
80
+ return self.initiation_interval
81
+
82
+ @property
83
+ def num_copies(self) -> int:
84
+ return int(np.ceil(self.length / self.I))
85
+
86
+ def __repr__(self):
87
+ return (
88
+ f"JointSWPWSResult(\n"
89
+ f" I={self.I}, L={self.length}, copies={self.num_copies}\n"
90
+ f" schedule={self.schedule}\n"
91
+ f" {self.warp_assignment}\n"
92
+ f")"
93
+ )
94
+
95
+
96
+ def swp_and_ws(
97
+ graph: DependenceGraph,
98
+ initial_schedule: ModuloScheduleResult,
99
+ I: int,
100
+ L: int,
101
+ enable_memory_constraints: bool = True,
102
+ enable_warp_constraints: bool = True,
103
+ timeout_ms: int = 120000,
104
+ verbose: bool = False,
105
+ ) -> Optional[JointSWPWSResult]:
106
+ """Solve the joint SWP + WS problem using Z3.
107
+
108
+ This is the core of Twill's Phase 2.
109
+
110
+ Args:
111
+ graph: Loop dependence graph
112
+ initial_schedule: Phase 1 result (initial M and I)
113
+ I: Initiation interval (from Phase 1)
114
+ L: Schedule length to try
115
+ enable_memory_constraints: Whether to include memory capacity constraints
116
+ enable_warp_constraints: Whether to include warp assignment constraints
117
+ timeout_ms: Z3 timeout in milliseconds
118
+ verbose: Print constraint statistics
119
+
120
+ Returns:
121
+ JointSWPWSResult if satisfiable, None otherwise
122
+ """
123
+ machine = graph.machine
124
+ V = graph.V
125
+ E = graph.E
126
+ n_copies = int(np.ceil(L / I))
127
+ # T = total time window for the straight-line program Q
128
+ # Must accommodate all overlapping copies: copy i starts at offset i*I
129
+ # The last copy (n_copies-1) starts at (n_copies-1)*I and needs L time steps
130
+ # So T = L + (n_copies - 1) * I
131
+ T = L + (n_copies - 1) * I
132
+
133
+ if verbose:
134
+ print(f" SMT: I={I}, L={L}, T={T}, copies={n_copies}")
135
+ print(f" SMT: |V|={len(V)}, |E|={len(E)}")
136
+
137
+ # Create Z3 solver
138
+ solver = z3.Solver()
139
+ solver.set("timeout", timeout_ms)
140
+
141
+ # ============================================================
142
+ # Boolean variables: op[v, i, t]
143
+ # True if instruction v, iteration copy i, is scheduled at clock cycle t
144
+ # ============================================================
145
+ op = {}
146
+ for v in V:
147
+ for i in range(n_copies):
148
+ for t in range(T):
149
+ op[v.name, i, t] = z3.Bool(f"op_{v.name}_{i}_{t}")
150
+
151
+ # ============================================================
152
+ # Figure 4: Modulo Scheduling Constraints
153
+ # ============================================================
154
+
155
+ # (Uniqueness): ∀v,i: Σ_t op[v,i,t] = 1
156
+ for v in V:
157
+ for i in range(n_copies):
158
+ # Exactly one time slot: use PbEq (pseudo-boolean exactly-k)
159
+ solver.add(z3.PbEq([(op[v.name, i, t], 1) for t in range(T)], 1))
160
+
161
+ # (Consistency): ∀v, i∈[1,n_copies), t: op[v,0,t] => op[v,i,t+i*I]
162
+ for v in V:
163
+ for i in range(1, n_copies):
164
+ for t in range(T):
165
+ t_shifted = t + i * I
166
+ if 0 <= t_shifted < T:
167
+ solver.add(z3.Implies(op[v.name, 0, t], op[v.name, i, t_shifted]))
168
+ else:
169
+ # If shifted time is out of bounds, op[v,0,t] must be false
170
+ # (because consistency requires scheduling at t_shifted which doesn't exist)
171
+ # Actually: if t+i*I >= T, then this copy i can't start from t in iteration 0
172
+ # Only add if the shift is valid
173
+ if t_shifted >= T:
174
+ solver.add(z3.Implies(op[v.name, 0, t], z3.BoolVal(False)))
175
+
176
+ # (Completion): ∀v,i,t: t + cycles(v) > T => ¬op[v,i,t]
177
+ for v in V:
178
+ for i in range(n_copies):
179
+ for t in range(T):
180
+ if t + v.cycles > T:
181
+ solver.add(z3.Not(op[v.name, i, t]))
182
+
183
+ # (Dependence): ∀i,t, (u,v,d,δ)∈E, t'∈[0,t+d): op[u,i,t] => ¬op[v,i+δ,t']
184
+ for edge in E:
185
+ u_name = edge.src
186
+ v_name = edge.dst
187
+ d = edge.delay
188
+ delta = edge.iteration_delay
189
+
190
+ for i in range(n_copies):
191
+ j = i + delta # target iteration copy
192
+ if j >= n_copies:
193
+ continue # skip if target iteration is beyond our window
194
+
195
+ for t in range(T):
196
+ # If u is scheduled at t, then v cannot be at any t' < t + d
197
+ implications = []
198
+ for t_prime in range(min(t + d, T)):
199
+ implications.append(z3.Not(op[v_name, j, t_prime]))
200
+
201
+ if implications:
202
+ solver.add(z3.Implies(op[u_name, i, t], z3.And(implications)))
203
+
204
+ # (Capacity): ∀t,f: Σ_{v,i,c∈[0,cycles(v))} op[v,i,t-c]·RRT[v][f,c] ≤ cap(f)
205
+ for t in range(T):
206
+ for f_idx, f_name in enumerate(machine.functional_units):
207
+ cap = machine.capacity(f_name)
208
+ if cap <= 0:
209
+ continue
210
+
211
+ terms = []
212
+ for v in V:
213
+ for i in range(n_copies):
214
+ for c in range(v.cycles):
215
+ usage = int(v.rrt[c, f_idx])
216
+ if usage > 0 and 0 <= t - c < T:
217
+ terms.append((op[v.name, i, t - c], usage))
218
+
219
+ if terms:
220
+ solver.add(z3.PbLe(terms, cap))
221
+
222
+ # ============================================================
223
+ # Figure 5: Memory Allocation Constraints (optional)
224
+ # ============================================================
225
+ live = {}
226
+ if enable_memory_constraints and machine.memory_spaces:
227
+ # live[v, i, t]: result of v (iteration i) is live at time t
228
+ for v in V:
229
+ if not v.memory_footprint:
230
+ continue
231
+ for i in range(n_copies):
232
+ for t in range(T + 1): # include T for Init constraint
233
+ live[v.name, i, t] = z3.Bool(f"live_{v.name}_{i}_{t}")
234
+
235
+ for v in V:
236
+ if not v.memory_footprint:
237
+ continue
238
+ for i in range(n_copies):
239
+ # (Init): loop-carried results of last copy are live at T
240
+ has_loop_carried_out = graph.has_loop_carried_output(v.name)
241
+ if i == n_copies - 1:
242
+ if has_loop_carried_out:
243
+ solver.add(live[v.name, i, T])
244
+ else:
245
+ solver.add(z3.Not(live[v.name, i, T]))
246
+
247
+ # (LiveProp-1): (live[v,i,t] ∧ op[v,i,t]) => ¬live[v,i,t-1]
248
+ for t in range(1, T + 1):
249
+ if (v.name, i, t) in live and (v.name, i, t - 1) in live:
250
+ if t < T: # op only defined for t < T
251
+ solver.add(z3.Implies(
252
+ z3.And(live[v.name, i, t], op[v.name, i, t]),
253
+ z3.Not(live[v.name, i, t - 1])
254
+ ))
255
+
256
+ # (LiveProp-2): (live[v,i,t] ∧ ¬op[v,i,t]) => live[v,i,t-1]
257
+ for t in range(1, T + 1):
258
+ if (v.name, i, t) in live and (v.name, i, t - 1) in live:
259
+ if t < T:
260
+ solver.add(z3.Implies(
261
+ z3.And(live[v.name, i, t], z3.Not(op[v.name, i, t])),
262
+ live[v.name, i, t - 1]
263
+ ))
264
+
265
+ # (DeadProp): propagation of deadness
266
+ for t in range(1, T + 1):
267
+ if (v.name, i, t) not in live or (v.name, i, t - 1) not in live:
268
+ continue
269
+
270
+ # Collect all consumers of v at time t
271
+ consumer_ops = []
272
+ for edge in graph.outgoing_edges(v.name):
273
+ j = i + edge.iteration_delay
274
+ if j < n_copies and t < T:
275
+ consumer_ops.append(op[edge.dst, j, t])
276
+
277
+ if consumer_ops:
278
+ # (DeadProp-1): (¬live[v,i,t] ∧ ∨consumers_active) => live[v,i,t-1]
279
+ solver.add(z3.Implies(
280
+ z3.And(z3.Not(live[v.name, i, t]), z3.Or(consumer_ops)),
281
+ live[v.name, i, t - 1]
282
+ ))
283
+ # (DeadProp-2): (¬live[v,i,t] ∧ ∧¬consumers_active) => ¬live[v,i,t-1]
284
+ solver.add(z3.Implies(
285
+ z3.And(z3.Not(live[v.name, i, t]),
286
+ *[z3.Not(c) for c in consumer_ops]),
287
+ z3.Not(live[v.name, i, t - 1])
288
+ ))
289
+ else:
290
+ # No consumers at this time -> propagate deadness
291
+ solver.add(z3.Implies(
292
+ z3.Not(live[v.name, i, t]),
293
+ z3.Not(live[v.name, i, t - 1])
294
+ ))
295
+
296
+ # (Memory Capacity): ∀t,m: Σ_{v,i} live[v,i,t]·footprint(v,m) ≤ capacity(m)
297
+ for t in range(T):
298
+ for mem_name, mem_cap in machine.memory_spaces.items():
299
+ terms = []
300
+ for v in V:
301
+ if mem_name in v.memory_footprint and v.memory_footprint[mem_name] > 0:
302
+ for i in range(n_copies):
303
+ if (v.name, i, t) in live:
304
+ terms.append(
305
+ (live[v.name, i, t], v.memory_footprint[mem_name])
306
+ )
307
+ if terms:
308
+ solver.add(z3.PbLe(terms, mem_cap))
309
+
310
+ # ============================================================
311
+ # Figure 6: Warp Assignment Constraints (optional)
312
+ # ============================================================
313
+ opw = {}
314
+ if enable_warp_constraints:
315
+ num_warps = machine.num_warps
316
+ W_vl = machine.variable_latency_warp
317
+
318
+ # opw[v, w]: instruction v is assigned to warp w
319
+ for v in V:
320
+ for w in range(num_warps):
321
+ opw[v.name, w] = z3.Bool(f"opw_{v.name}_{w}")
322
+
323
+ # (WarpUniqueness): ∀v: Σ_w opw[v,w] = 1
324
+ for v in V:
325
+ solver.add(z3.PbEq([(opw[v.name, w], 1) for w in range(num_warps)], 1))
326
+
327
+ # (VariableLatency): ∀v: variable_latency(v) <=> opw[v, W_vl]
328
+ # Only enforce if there ARE variable-latency ops (otherwise W_vl is unused
329
+ # and all warps should be available for compute)
330
+ has_var_lat_ops = any(v.variable_latency for v in V)
331
+ for v in V:
332
+ if v.variable_latency:
333
+ # Variable-latency ops MUST go to W_vl
334
+ solver.add(opw[v.name, W_vl])
335
+ elif has_var_lat_ops:
336
+ # Non-variable-latency ops must NOT go to W_vl (reserved for var-lat)
337
+ solver.add(z3.Not(opw[v.name, W_vl]))
338
+
339
+ # (WarpCapacity): Per-warp resource constraints
340
+ # For each time slot t ∈ [0, I), warp w, functional unit f:
341
+ # Σ_{v,c} opw[v,w] · op[v,0,t_shift] · RRT[v][f,c] ≤ per_warp_cap(f)
342
+ #
343
+ # Simplified: for each warp, the total resource usage across
344
+ # its assigned instructions in a single I-window must fit
345
+ # This is a tighter constraint that prevents resource conflicts within a warp
346
+ for w in range(num_warps):
347
+ for f_idx, f_name in enumerate(machine.functional_units):
348
+ cap = machine.capacity(f_name)
349
+ if cap <= 0:
350
+ continue
351
+
352
+ # Total usage across one initiation interval per warp
353
+ terms = []
354
+ for v in V:
355
+ total_fu_use = int(v.rrt[:, f_idx].sum())
356
+ if total_fu_use > 0:
357
+ terms.append((opw[v.name, w], total_fu_use))
358
+
359
+ if terms:
360
+ # Per-warp capacity: at most cap per I cycles
361
+ solver.add(z3.PbLe(terms, cap * I))
362
+
363
+ # (Cross-warp synchronization): When two ops are on different warps
364
+ # and have a dependence, a barrier is needed (mbarrier on Hopper/Blackwell).
365
+ # This is modeled implicitly: the dependence constraints already enforce
366
+ # timing, and the WS assignment determines which barriers are needed.
367
+ # Twill emits barrier annotations in code generation.
368
+
369
+ # ============================================================
370
+ # Solve
371
+ # ============================================================
372
+ if verbose:
373
+ stats = solver.statistics()
374
+ print(f" SMT: Constraints added. Solving...")
375
+
376
+ result = solver.check()
377
+
378
+ if result != z3.sat:
379
+ if verbose:
380
+ print(f" SMT: {result}")
381
+ return None
382
+
383
+ model = solver.model()
384
+
385
+ # Extract schedule M*
386
+ new_schedule = {}
387
+ for v in V:
388
+ for t in range(T):
389
+ if z3.is_true(model.evaluate(op[v.name, 0, t])):
390
+ new_schedule[v.name] = t
391
+ break
392
+ else:
393
+ # Should not happen if constraints are correct
394
+ raise RuntimeError(f"No time slot found for instruction {v.name} in iteration 0")
395
+
396
+ # Extract warp assignment A*
397
+ warp_assign = {}
398
+ if enable_warp_constraints:
399
+ for v in V:
400
+ for w in range(machine.num_warps):
401
+ if z3.is_true(model.evaluate(opw[v.name, w])):
402
+ warp_assign[v.name] = w
403
+ break
404
+ else:
405
+ raise RuntimeError(f"No warp found for instruction {v.name}")
406
+ else:
407
+ # Default: all on warp 0
408
+ for v in V:
409
+ warp_assign[v.name] = 0
410
+
411
+ wa = WarpAssignment(warp_assign, machine.num_warps)
412
+
413
+ # Label warps
414
+ if enable_warp_constraints:
415
+ vl_warp = machine.variable_latency_warp
416
+ wa.warp_names[vl_warp] = f"Warp {vl_warp} (variable-latency/producer)"
417
+ for w in range(machine.num_warps):
418
+ if w != vl_warp and w not in wa.warp_names:
419
+ wa.warp_names[w] = f"Warp {w} (compute/consumer)"
420
+
421
+ return JointSWPWSResult(
422
+ schedule=new_schedule,
423
+ I=I,
424
+ L=L,
425
+ warp_assignment=wa,
426
+ )