twill-swp-ws / twill /codegen.py
AshenNav's picture
Upload twill/codegen.py with huggingface_hub
cb0545b verified
"""
Code Generation: Emit pipelined and warp-annotated code from Twill's solution.
Based on Sections 3.2 and 5 of the paper.
Generates:
1. Prologue: primes the pipeline
2. Steady State: the main loop body
3. Epilogue: drains the pipeline
Each instruction is annotated with its assigned warp and synchronization barriers.
"""
from typing import Dict, List, Optional, Tuple
from twill.graph import DependenceGraph, Instruction, DependenceEdge
from twill.smt_joint import JointSWPWSResult, WarpAssignment
import math
class PipelinedInstruction:
    """One scheduled instance of an instruction in the pipelined code.

    Attributes:
        name: Name of the original instruction.
        iteration: Loop iteration this instance belongs to.
        cycle: Clock cycle of this instance in the pipelined schedule.
        warp: Index of the warp the instruction is assigned to.
        phase: "prologue", "steady", or "epilogue".
        needs_barrier_before: Whether a barrier synchronization is required
            before this instruction; always False right after construction.
        barrier_id: Identifier of that barrier, or None when none is set.
    """

    def __init__(self, name: str, iteration: int, cycle: int, warp: int, phase: str):
        self.name, self.iteration, self.cycle = name, iteration, cycle
        self.warp, self.phase = warp, phase
        # Barrier fields start cleared; they are plain attributes that other
        # code may set after the instance has been created.
        self.needs_barrier_before: bool = False
        self.barrier_id: Optional[str] = None

    def __repr__(self):
        # One schedule line: cycle (right-aligned to width 3), name, iteration
        # and warp, followed by the barrier tag only when a barrier is needed.
        text = f" cycle {self.cycle:3d}: {self.name}[iter={self.iteration}] @ warp {self.warp}"
        if self.needs_barrier_before:
            text += f" [barrier:{self.barrier_id}]"
        return text
class PipelinedCode:
    """Container for the three phases of a pipelined schedule.

    Holds the prologue, steady-state, and epilogue instruction lists together
    with the initiation interval, schedule length, and number of overlapping
    copies. Every field starts empty / zero and is filled in by the caller.
    """

    def __init__(self):
        self.initiation_interval: int = 0
        self.schedule_length: int = 0
        self.num_copies: int = 0
        self.prologue: List[PipelinedInstruction] = []
        self.steady_state: List[PipelinedInstruction] = []
        self.epilogue: List[PipelinedInstruction] = []

    def all_instructions(self) -> List[PipelinedInstruction]:
        """Return a new list: prologue, then steady state, then epilogue."""
        combined = self.prologue + self.steady_state
        return combined + self.epilogue

    def __repr__(self):
        """Render a summary header followed by the three titled sections."""
        header = (
            f"PipelinedCode(I={self.initiation_interval}, L={self.schedule_length}, "
            f"copies={self.num_copies})"
        )
        sections = (
            ("=== PROLOGUE ===", self.prologue),
            ("=== STEADY STATE (loop body) ===", self.steady_state),
            ("=== EPILOGUE ===", self.epilogue),
        )
        rendered = [header]
        for title, instructions in sections:
            # Each section is preceded by a blank line and its title.
            rendered.append("")
            rendered.append(title)
            rendered.extend(str(instruction) for instruction in instructions)
        return "\n".join(rendered)
def generate_pipelined_code(
    graph: DependenceGraph,
    result: JointSWPWSResult,
) -> PipelinedCode:
    """Expand a Twill solution into prologue / steady-state / epilogue code.

    Every instruction ``v`` in ``graph.V`` is instantiated once per copy ``i``
    in ``range(result.num_copies)`` at absolute time
    ``result.schedule[v.name] + i * result.I``. Instances whose time is not
    below ``result.length`` are dropped. The remaining instances are ordered
    by (time, name) and bucketed by time:

    * prologue:     time < (num_copies - 1) * I
    * steady state: (num_copies - 1) * I <= time < num_copies * I
    * epilogue:     time >= num_copies * I

    When ``num_copies <= 1`` every instance is placed in the steady state.
    Barrier annotations are added via ``_annotate_barriers`` before returning.

    Args:
        graph: The dependence graph; ``graph.V`` supplies the instructions.
        result: The joint SWP+WS result; its ``I``, ``length``,
            ``num_copies``, ``schedule`` and ``warp_assignment`` are read.

    Returns:
        PipelinedCode with prologue, steady state, and epilogue populated.
    """
    interval = result.I
    length = result.length
    copies = result.num_copies
    start_of = result.schedule
    assignment = result.warp_assignment

    code = PipelinedCode()
    code.initiation_interval = interval
    code.schedule_length = length
    code.num_copies = copies

    # Flatten the schedule: one (time, name, copy, warp) tuple per instance.
    # NOTE(review): instances are kept only when time < length, so the
    # epilogue bucket below can only receive entries if
    # length > num_copies * I -- confirm this matches how num_copies is derived.
    instances = []
    for node in graph.V:
        for copy_index in range(copies):
            when = start_of[node.name] + copy_index * interval
            if when >= length:
                continue
            instances.append(
                (when, node.name, copy_index, assignment.warp_of(node.name))
            )

    # Phase boundaries on the absolute time axis.
    single_copy = copies <= 1
    prologue_limit = (copies - 1) * interval if copies > 1 else 0
    steady_limit = copies * interval if copies > 0 else interval

    buckets = {
        "prologue": code.prologue,
        "steady": code.steady_state,
        "epilogue": code.epilogue,
    }
    # Order by time, breaking ties by instruction name.
    for when, name, iteration, warp in sorted(instances, key=lambda op: op[:2]):
        if single_copy or prologue_limit <= when < steady_limit:
            phase = "steady"
        elif when < prologue_limit:
            phase = "prologue"
        else:
            phase = "epilogue"
        buckets[phase].append(
            PipelinedInstruction(
                name=name,
                iteration=iteration,
                cycle=when,
                warp=warp,
                phase=phase,
            )
        )

    # Mark instructions that sit on the consuming end of a cross-warp edge.
    _annotate_barriers(graph, result, code)
    return code
def _annotate_barriers(
    graph: DependenceGraph,
    result: JointSWPWSResult,
    code: PipelinedCode,
):
    """Flag instructions that need a barrier because of a cross-warp edge.

    For every edge in ``graph.E`` whose source and destination are assigned
    to different warps, each destination instance ``(edge.dst, i +
    edge.iteration_delay)`` for ``i`` in ``range(result.num_copies)`` that
    exists in ``code`` gets ``needs_barrier_before = True`` and a
    ``barrier_id`` of the form ``bar_<n>``. Only the destination side is
    annotated. If several cross-warp edges reach the same instance, the
    edge processed last determines its ``barrier_id``.

    The original intent is that on Hopper/Blackwell this corresponds to
    mbarrier-based synchronization.

    Args:
        graph: The dependence graph; ``graph.E`` supplies the edges.
        result: Supplies ``warp_assignment`` and ``num_copies``.
        code: The pipelined code whose instructions are annotated in place.
    """
    assignment = result.warp_assignment

    # Look up emitted instruction instances by (name, iteration).
    by_instance = {
        (instr.name, instr.iteration): instr for instr in code.all_instructions()
    }

    next_barrier = 0
    for edge in graph.E:
        if assignment.warp_of(edge.src) == assignment.warp_of(edge.dst):
            # Same warp on both ends: nothing to annotate for this edge.
            continue

        label = f"bar_{next_barrier}"
        for copy_index in range(result.num_copies):
            consumer = by_instance.get((edge.dst, copy_index + edge.iteration_delay))
            if consumer is None:
                # That destination instance was not emitted.
                continue
            consumer.needs_barrier_before = True
            consumer.barrier_id = label
        # NOTE(review): the id advances once per cross-warp edge, so all
        # copies of one edge share a barrier id -- confirm this is intended
        # rather than one id per annotated instance.
        next_barrier += 1
def _pseudocode_op_lines(instr, target: str, trailer: str = "") -> List[str]:
    """Return the pseudocode line(s) for one pipelined instruction.

    The result is the operation line, preceded by an ``mbarrier.wait`` line
    only when the instruction needs a barrier before it.
    """
    op_line = (
        f"/* cycle {instr.cycle}, warp {instr.warp} */ "
        f"{target} = {instr.name}(...);{trailer}"
    )
    if instr.needs_barrier_before:
        return [f" mbarrier.wait({instr.barrier_id});", op_line]
    return [op_line]


def generate_pseudocode(
    graph: DependenceGraph,
    result: JointSWPWSResult,
) -> str:
    """Generate human-readable pseudocode from the Twill solution.

    The output consists of a comment header (initiation interval, schedule
    length, number of overlapping copies, and the warp assigned to each
    instruction in ``graph.V``), followed by the prologue (if any), the
    steady-state section, and the epilogue (if any) produced by
    ``generate_pipelined_code``. Each operation is tagged with its cycle and
    warp, and is preceded by an ``mbarrier.wait`` line when it needs a barrier.

    All three phases format their instructions the same way, so an
    instruction without a barrier never contributes an empty line.

    Args:
        graph: The dependence graph.
        result: The joint SWP+WS result from Twill.

    Returns:
        The annotated pseudocode as a single newline-joined string.
    """
    code = generate_pipelined_code(graph, result)
    wa = result.warp_assignment

    lines = [
        "// Twill-generated pipelined schedule",
        f"// Initiation Interval I = {result.I}",
        f"// Schedule Length L = {result.length}",
        f"// Overlapping copies = {result.num_copies}",
        "//",
        "// Warp Assignment:",
    ]

    # Warp assignment summary: one comment line per instruction.
    for v in graph.V:
        warp = wa.warp_of(v.name)
        warp_label = wa.warp_names.get(warp, f"warp_{warp}")
        var_lat = " [variable-latency]" if v.variable_latency else ""
        lines.append(f"// {v.name} -> {warp_label}{var_lat}")
    lines.append("")

    # Prologue: results are named per iteration (name_iteration).
    if code.prologue:
        lines.append("// ---- PROLOGUE ----")
        for instr in code.prologue:
            lines.extend(
                _pseudocode_op_lines(instr, f"{instr.name}_{instr.iteration}")
            )
        lines.append("")

    # Steady state: always emitted; the iteration is shown as an offset.
    lines.append("// ---- STEADY STATE (for i in range(N)): ----")
    for instr in code.steady_state:
        lines.extend(
            _pseudocode_op_lines(
                instr, instr.name, f" // iter offset={instr.iteration}"
            )
        )
    lines.append("")

    # Epilogue: same per-iteration naming as the prologue.
    if code.epilogue:
        lines.append("// ---- EPILOGUE ----")
        for instr in code.epilogue:
            lines.extend(
                _pseudocode_op_lines(instr, f"{instr.name}_{instr.iteration}")
            )

    return "\n".join(lines)
def generate_cuda_skeleton(
    graph: DependenceGraph,
    result: JointSWPWSResult,
) -> str:
    """Generate a CUDA C++ skeleton from the Twill solution.

    The skeleton is a text template, not compilable code: a header comment,
    two includes, and a ``twill_kernel`` body with one ``if (warp_id == w)``
    branch for every warp in ``range(graph.machine.num_warps)`` that has
    instructions assigned. Inside each branch the warp's prologue,
    steady-state loop, and epilogue operations (taken from
    ``generate_pipelined_code``) are listed, each preceded by an
    ``mbarrier::arrive_and_wait`` line when it needs a barrier.

    Args:
        graph: The dependence graph; ``graph.machine`` supplies the target
            name and warp count.
        result: The joint SWP+WS result from Twill.

    Returns:
        The skeleton as a single newline-joined string.
    """
    wa = result.warp_assignment
    machine = graph.machine
    code = generate_pipelined_code(graph, result)

    out = [
        "// CUDA C++ skeleton generated by Twill",
        f"// Target: {machine.name}",
        f"// I={result.I}, L={result.length}, copies={result.num_copies}",
        "",
        "#include <cuda.h>",
        "#include <cuda/barrier>",
        "",
        "__global__ void twill_kernel(...) {",
        " const int warp_id = threadIdx.x / 32;",
        " ",
    ]

    def emit_ops(ops, with_iteration):
        # One line per operation, with a barrier line in front when required.
        for op in ops:
            if op.needs_barrier_before:
                out.append(f" mbarrier::arrive_and_wait({op.barrier_id});")
            if with_iteration:
                out.append(f" {op.name}(...); // iter {op.iteration}")
            else:
                out.append(f" {op.name}(...);")

    # Collect the warps that actually have instructions, in index order.
    occupied = {}
    for index in range(machine.num_warps):
        assigned = wa.instructions_on_warp(index)
        if assigned:
            occupied[index] = assigned

    for index, assigned in occupied.items():
        label = wa.warp_names.get(index, f"warp {index}")
        out.append(f" if (warp_id == {index}) {{ // {label}")
        out.append(f" // Operations: {assigned}")

        prologue_ops = [op for op in code.prologue if op.warp == index]
        if prologue_ops:
            out.append(" // --- Prologue ---")
            emit_ops(prologue_ops, with_iteration=True)

        steady_ops = [op for op in code.steady_state if op.warp == index]
        if steady_ops:
            out.append(" // --- Steady State ---")
            out.append(" for (int i = 0; i < N; i++) {")
            emit_ops(steady_ops, with_iteration=False)
            out.append(" }")

        epilogue_ops = [op for op in code.epilogue if op.warp == index]
        if epilogue_ops:
            out.append(" // --- Epilogue ---")
            emit_ops(epilogue_ops, with_iteration=True)

        # Close this warp's branch and leave a spacer line after it.
        out.append(" }")
        out.append(" ")

    out.append("}")
    return "\n".join(out)