AshenNav committed on
Commit
cb0545b
·
verified ·
1 Parent(s): 2c2e63a

Upload twill/codegen.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. twill/codegen.py +331 -0
twill/codegen.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code Generation: Emit pipelined and warp-annotated code from Twill's solution.
3
+
4
+ Based on Sections 3.2 and 5 of the paper.
5
+
6
+ Generates:
7
+ 1. Prologue: primes the pipeline
8
+ 2. Steady State: the main loop body
9
+ 3. Epilogue: drains the pipeline
10
+
11
+ Each instruction is annotated with its assigned warp and synchronization barriers.
12
+ """
13
+
14
+ from typing import Dict, List, Optional, Tuple
15
+ from twill.graph import DependenceGraph, Instruction, DependenceEdge
16
+ from twill.smt_joint import JointSWPWSResult, WarpAssignment
17
+ import math
18
+
19
+
20
class PipelinedInstruction:
    """One scheduled instance of an instruction in the pipelined code.

    Attributes:
        name: Name of the original instruction.
        iteration: Loop iteration (copy index) this instance belongs to.
        cycle: Clock cycle of this instance in the pipelined schedule.
        warp: Index of the warp the instruction is assigned to.
        phase: One of "prologue", "steady", or "epilogue".
        needs_barrier_before: True when a barrier synchronization must happen
            before this instance; initialised to False.
        barrier_id: Identifier of that barrier, or None when no barrier is set.
    """

    def __init__(self, name: str, iteration: int, cycle: int, warp: int, phase: str):
        self.name, self.iteration = name, iteration
        self.cycle, self.warp, self.phase = cycle, warp, phase
        # A fresh instance carries no barrier until one is assigned to it.
        self.needs_barrier_before: bool = False
        self.barrier_id: Optional[str] = None

    def __repr__(self):
        # The barrier suffix is only shown when the flag is set, regardless of
        # whether barrier_id itself holds a value.
        suffix = ""
        if self.needs_barrier_before:
            suffix = f" [barrier:{self.barrier_id}]"
        return (
            f" cycle {self.cycle:3d}: {self.name}[iter={self.iteration}]"
            f" @ warp {self.warp}{suffix}"
        )
44
+
45
+
46
class PipelinedCode:
    """Pipelined code split into prologue, steady-state and epilogue lists.

    Attributes:
        prologue: Instructions placed in the prologue.
        steady_state: Instructions placed in the steady-state loop body.
        epilogue: Instructions placed in the epilogue.
        initiation_interval: Initiation interval of the schedule (default 0).
        schedule_length: Length of the schedule (default 0).
        num_copies: Number of overlapping copies (default 0).
    """

    def __init__(self):
        # Scalar schedule parameters; all start at zero.
        self.initiation_interval: int = 0
        self.schedule_length: int = 0
        self.num_copies: int = 0
        # One list per phase, filled in by the code generator.
        self.prologue: List[PipelinedInstruction] = []
        self.steady_state: List[PipelinedInstruction] = []
        self.epilogue: List[PipelinedInstruction] = []

    def all_instructions(self) -> List[PipelinedInstruction]:
        """Return a new list: prologue, then steady state, then epilogue."""
        return [*self.prologue, *self.steady_state, *self.epilogue]

    def __repr__(self):
        header = (
            f"PipelinedCode(I={self.initiation_interval}, L={self.schedule_length}, "
            f"copies={self.num_copies})"
        )
        sections = (
            ("=== PROLOGUE ===", self.prologue),
            ("=== STEADY STATE (loop body) ===", self.steady_state),
            ("=== EPILOGUE ===", self.epilogue),
        )
        rendered = [header]
        for title, instrs in sections:
            # Every section is introduced by a blank line and its banner,
            # even when it holds no instructions.
            rendered.append("")
            rendered.append(title)
            rendered.extend(str(instr) for instr in instrs)
        return "\n".join(rendered)
81
+
82
+
83
def generate_pipelined_code(
    graph: DependenceGraph,
    result: JointSWPWSResult,
) -> PipelinedCode:
    """Build prologue / steady-state / epilogue code from a Twill solution.

    Each instruction ``v`` in ``graph.V`` is instantiated once per copy ``i`` in
    ``range(result.num_copies)`` at absolute time
    ``result.schedule[v.name] + i * result.I``. Instances whose time is not
    below ``result.length`` are dropped. The survivors are ordered by
    (time, name) and split by time:

    * prologue:     time < (num_copies - 1) * I
    * steady state: (num_copies - 1) * I <= time < num_copies * I
    * epilogue:     time >= num_copies * I

    When there is at most one copy, every instance goes to the steady state.
    Barrier annotations are added last via ``_annotate_barriers``.

    Args:
        graph: The dependence graph; its ``V`` is iterated for instructions.
        result: The joint SWP+WS result (``I``, ``length``, ``num_copies``,
            ``schedule`` and ``warp_assignment`` are read).

    Returns:
        A PipelinedCode holding the three phases and the schedule parameters.
    """
    interval = result.I
    length = result.length
    copies = result.num_copies
    start_of = result.schedule
    warps = result.warp_assignment

    code = PipelinedCode()
    code.initiation_interval = interval
    code.schedule_length = length
    code.num_copies = copies

    # Expand the schedule into (time, name, copy, warp) tuples, one per
    # instruction instance that still starts inside the schedule length.
    expanded = []
    for node in graph.V:
        for copy in range(copies):
            when = start_of[node.name] + copy * interval
            if when >= length:
                continue
            expanded.append((when, node.name, copy, warps.warp_of(node.name)))

    # Order by time, breaking ties by instruction name.
    expanded = sorted(expanded, key=lambda op: op[:2])

    # Phase boundaries on the absolute time axis.
    prologue_end = (copies - 1) * interval if copies > 1 else 0
    steady_end = copies * interval if copies > 0 else interval

    def phase_of(when):
        # With a single copy there is no overlap, so nothing is prologue
        # or epilogue.
        if copies <= 1:
            return "steady"
        if when < prologue_end:
            return "prologue"
        if when < steady_end:
            return "steady"
        return "epilogue"

    buckets = {
        "prologue": code.prologue,
        "steady": code.steady_state,
        "epilogue": code.epilogue,
    }
    for when, name, copy, warp in expanded:
        phase = phase_of(when)
        buckets[phase].append(
            PipelinedInstruction(
                name=name,
                iteration=copy,
                cycle=when,
                warp=warp,
                phase=phase,
            )
        )

    # Mark instructions that sit behind a cross-warp dependency.
    _annotate_barriers(graph, result, code)

    return code
164
+
165
+
166
def _annotate_barriers(
    graph: DependenceGraph,
    result: JointSWPWSResult,
    code: PipelinedCode,
):
    """Flag instructions in ``code`` that need a barrier before they run.

    For every edge in ``graph.E`` whose source and destination are assigned to
    different warps, and for every copy index ``i`` in
    ``range(result.num_copies)``, the destination instance of iteration
    ``i + edge.iteration_delay`` (when present in ``code``) gets
    ``needs_barrier_before = True`` and a fresh ``bar_<n>`` identifier.

    On Hopper/Blackwell, this uses mbarrier-based synchronization.

    Note: ``barrier_id`` holds a single value, so when several cross-warp
    edges target the same instance, the identifier assigned last is the one
    that remains.

    Args:
        graph: The dependence graph supplying the edges.
        result: The joint result supplying the warp assignment and copy count.
        code: The pipelined code whose instructions are annotated in place.
    """
    wa = result.warp_assignment

    # Index every emitted instance by (instruction name, iteration).
    by_key = {
        (instr.name, instr.iteration): instr
        for instr in code.all_instructions()
    }

    next_barrier = 0
    for edge in graph.E:
        producer_warp = wa.warp_of(edge.src)
        consumer_warp = wa.warp_of(edge.dst)
        if producer_warp == consumer_warp:
            # Same warp on both ends: no barrier is emitted for this edge.
            continue

        for i in range(result.num_copies):
            consumer = by_key.get((edge.dst, i + edge.iteration_delay))
            if consumer is None:
                # That destination instance was not emitted.
                continue
            consumer.needs_barrier_before = True
            consumer.barrier_id = f"bar_{next_barrier}"
            next_barrier += 1
200
+
201
+
202
def generate_pseudocode(
    graph: DependenceGraph,
    result: JointSWPWSResult,
) -> str:
    """Generate human-readable pseudocode from the Twill solution.

    The output starts with a comment header (initiation interval, schedule
    length, number of overlapping copies) and a per-instruction warp
    assignment summary, followed by the prologue (if any), the steady-state
    section (always emitted) and the epilogue (if any). An instruction flagged
    ``needs_barrier_before`` is preceded by an ``mbarrier.wait(...)`` line.

    All three sections format instructions the same way: a line is emitted
    for the barrier only when one is needed, so instructions without a
    barrier are not preceded by an empty line.

    Args:
        graph: The dependence graph; ``graph.V`` supplies the instructions
            listed in the warp-assignment summary.
        result: The joint SWP+WS result from Twill.

    Returns:
        The annotated pseudocode as a single newline-joined string.
    """
    code = generate_pipelined_code(graph, result)
    wa = result.warp_assignment

    def render(instr, target, trailer=""):
        # One statement, with its barrier wait on the preceding line when the
        # instruction needs one.
        wait = (
            f" mbarrier.wait({instr.barrier_id});\n"
            if instr.needs_barrier_before
            else ""
        )
        return (
            f"{wait}/* cycle {instr.cycle}, warp {instr.warp} */ "
            f"{target} = {instr.name}(...);{trailer}"
        )

    lines = [
        "// Twill-generated pipelined schedule",
        f"// Initiation Interval I = {result.I}",
        f"// Schedule Length L = {result.length}",
        f"// Overlapping copies = {result.num_copies}",
        "//",
        "// Warp Assignment:",
    ]

    # Warp assignment summary; warps without a registered name fall back to
    # a generic "warp_<index>" label.
    for v in graph.V:
        warp = wa.warp_of(v.name)
        warp_label = wa.warp_names.get(warp, f"warp_{warp}")
        var_lat = " [variable-latency]" if v.variable_latency else ""
        lines.append(f"// {v.name} -> {warp_label}{var_lat}")
    lines.append("")

    # Prologue: results are named per iteration (<name>_<iteration>).
    if code.prologue:
        lines.append("// ---- PROLOGUE ----")
        lines.extend(
            render(instr, f"{instr.name}_{instr.iteration}")
            for instr in code.prologue
        )
        lines.append("")

    # Steady state: always emitted; the iteration is shown as an offset.
    lines.append("// ---- STEADY STATE (for i in range(N)): ----")
    lines.extend(
        render(instr, instr.name, f" // iter offset={instr.iteration}")
        for instr in code.steady_state
    )
    lines.append("")

    # Epilogue: same per-iteration naming as the prologue.
    if code.epilogue:
        lines.append("// ---- EPILOGUE ----")
        lines.extend(
            render(instr, f"{instr.name}_{instr.iteration}")
            for instr in code.epilogue
        )

    return "\n".join(lines)
258
+
259
+
260
def generate_cuda_skeleton(
    graph: DependenceGraph,
    result: JointSWPWSResult,
) -> str:
    """Render a CUDA C++ skeleton for the Twill solution.

    The skeleton is a single ``twill_kernel`` with one ``if (warp_id == w)``
    branch per warp that has instructions assigned to it. Inside each branch
    the warp's prologue, steady-state loop and epilogue are listed (each only
    when non-empty), and an ``mbarrier::arrive_and_wait(...)`` line precedes
    every instruction flagged ``needs_barrier_before``.

    Args:
        graph: The dependence graph; ``graph.machine`` supplies the target
            name and the number of warps.
        result: The joint SWP+WS result from Twill.

    Returns:
        The skeleton source as a single newline-joined string.
    """
    wa = result.warp_assignment
    machine = graph.machine
    code = generate_pipelined_code(graph, result)

    out = [
        "// CUDA C++ skeleton generated by Twill",
        f"// Target: {machine.name}",
        f"// I={result.I}, L={result.length}, copies={result.num_copies}",
        "",
        "#include <cuda.h>",
        "#include <cuda/barrier>",
        "",
        "__global__ void twill_kernel(...) {",
        " const int warp_id = threadIdx.x / 32;",
        " ",
    ]

    def emit(instrs, with_iteration):
        # Append one call per instruction, preceded by its barrier if flagged.
        for instr in instrs:
            if instr.needs_barrier_before:
                out.append(f" mbarrier::arrive_and_wait({instr.barrier_id});")
            suffix = f" // iter {instr.iteration}" if with_iteration else ""
            out.append(f" {instr.name}(...);{suffix}")

    # Collect (warp index, assigned operations) for every warp that has work,
    # querying all warps up front before any branch is rendered.
    per_warp = ((w, wa.instructions_on_warp(w)) for w in range(machine.num_warps))
    occupied = [(w, ops) for w, ops in per_warp if ops]

    for w, ops in occupied:
        label = wa.warp_names.get(w, f"warp {w}")
        out.append(f" if (warp_id == {w}) {{ // {label}")
        out.append(f" // Operations: {ops}")

        # Prologue operations of this warp.
        mine = [instr for instr in code.prologue if instr.warp == w]
        if mine:
            out.append(" // --- Prologue ---")
            emit(mine, with_iteration=True)

        # Steady-state loop of this warp.
        mine = [instr for instr in code.steady_state if instr.warp == w]
        if mine:
            out.append(" // --- Steady State ---")
            out.append(" for (int i = 0; i < N; i++) {")
            emit(mine, with_iteration=False)
            out.append(" }")

        # Epilogue operations of this warp.
        mine = [instr for instr in code.epilogue if instr.warp == w]
        if mine:
            out.append(" // --- Epilogue ---")
            emit(mine, with_iteration=True)

        out.append(" }")
        out.append(" ")

    out.append("}")

    return "\n".join(out)