"""
Code Generation: Emit pipelined and warp-annotated code from Twill's solution.

Based on Section 3.2 and 5 of the paper.

Generates:
1. Prologue: primes the pipeline
2. Steady State: the main loop body
3. Epilogue: drains the pipeline

Each instruction is annotated with its assigned warp and synchronization barriers.
"""

from typing import Dict, List, Optional, Tuple
from twill.graph import DependenceGraph, Instruction, DependenceEdge
from twill.smt_joint import JointSWPWSResult, WarpAssignment
import math


class PipelinedInstruction:
    """A single instruction in the pipelined code.

    Attributes:
        name: Original instruction name
        iteration: Which loop iteration this instance belongs to
        cycle: Clock cycle in the pipelined schedule
        warp: Assigned warp index
        phase: "prologue", "steady", or "epilogue"
        needs_barrier_before: Whether a barrier synchronization is needed before this
        barrier_id: Identifier for the barrier (if needed)
    """

    def __init__(self, name: str, iteration: int, cycle: int, warp: int, phase: str):
        self.name = name
        self.iteration = iteration
        self.cycle = cycle
        self.warp = warp
        self.phase = phase
        # Barrier fields start cleared; _annotate_barriers() fills them in.
        self.needs_barrier_before: bool = False
        self.barrier_id: Optional[str] = None

    def __repr__(self):
        barrier_str = f" [barrier:{self.barrier_id}]" if self.needs_barrier_before else ""
        return (
            f" cycle {self.cycle:3d}: {self.name}[iter={self.iteration}] "
            f"@ warp {self.warp}{barrier_str}"
        )


class PipelinedCode:
    """Complete pipelined code with prologue, steady state, and epilogue."""

    def __init__(self):
        self.prologue: List[PipelinedInstruction] = []
        self.steady_state: List[PipelinedInstruction] = []
        self.epilogue: List[PipelinedInstruction] = []
        self.initiation_interval: int = 0
        self.schedule_length: int = 0
        self.num_copies: int = 0

    def all_instructions(self) -> List[PipelinedInstruction]:
        """Return a new list: prologue, then steady state, then epilogue."""
        return self.prologue + self.steady_state + self.epilogue

    def __repr__(self):
        lines = [
            f"PipelinedCode(I={self.initiation_interval}, L={self.schedule_length}, "
            f"copies={self.num_copies})",
            "",
            "=== PROLOGUE ===",
        ]
        lines.extend(str(instr) for instr in self.prologue)
        lines.append("")
        lines.append("=== STEADY STATE (loop body) ===")
        lines.extend(str(instr) for instr in self.steady_state)
        lines.append("")
        lines.append("=== EPILOGUE ===")
        lines.extend(str(instr) for instr in self.epilogue)
        return "\n".join(lines)


def generate_pipelined_code(
    graph: DependenceGraph,
    result: JointSWPWSResult,
) -> PipelinedCode:
    """Generate pipelined and warp-annotated code from a Twill solution.

    The code generation follows the standard modulo scheduling approach:
    1. Prologue: iterations 0..n_copies-2, instructions before their steady-state position
    2. Steady State: all n_copies overlapping, running in lockstep offset by I
    3. Epilogue: iterations 1..n_copies-1, instructions after steady-state ends

    Args:
        graph: The dependence graph
        result: The joint SWP+WS result from Twill

    Returns:
        PipelinedCode with prologue, steady state, and epilogue
    """
    I = result.I
    L = result.length
    n_copies = result.num_copies
    M = result.schedule
    wa = result.warp_assignment

    code = PipelinedCode()
    code.initiation_interval = I
    code.schedule_length = L
    code.num_copies = n_copies

    # Build the full straight-line schedule Q.
    # For each instruction v and each copy i:
    #     absolute_time(v, i) = M(v) + i * I
    all_ops = []
    for v in graph.V:
        for i in range(n_copies):
            abs_time = M[v.name] + i * I
            if abs_time < L:  # must finish within the schedule
                warp = wa.warp_of(v.name)
                all_ops.append((abs_time, v.name, i, warp))

    # Sort by time, breaking ties by instruction name for a stable listing.
    all_ops.sort(key=lambda op: (op[0], op[1]))

    # Partition into prologue, steady state, epilogue:
    #   Prologue:     time < (n_copies - 1) * I
    #   Steady state: (n_copies - 1) * I <= time < n_copies * I (one full I window)
    #   Epilogue:     time >= n_copies * I (but < L)
    prologue_end = (n_copies - 1) * I if n_copies > 1 else 0
    steady_end = n_copies * I if n_copies > 0 else I

    buckets = {
        "prologue": code.prologue,
        "steady": code.steady_state,
        "epilogue": code.epilogue,
    }

    for abs_time, name, iteration, warp in all_ops:
        if n_copies <= 1:
            # Only one copy -> everything is steady state
            phase = "steady"
        elif abs_time < prologue_end:
            phase = "prologue"
        elif abs_time < steady_end:
            phase = "steady"
        else:
            phase = "epilogue"

        buckets[phase].append(
            PipelinedInstruction(
                name=name,
                iteration=iteration,
                cycle=abs_time,
                warp=warp,
                phase=phase,
            )
        )

    # Add barrier annotations for cross-warp dependencies
    _annotate_barriers(graph, result, code)
    return code


def _annotate_barriers(
    graph: DependenceGraph,
    result: JointSWPWSResult,
    code: PipelinedCode,
):
    """Annotate instructions that need barrier synchronization.

    A barrier is needed when:
    1. There's a dependence edge (u, v, d, δ)
    2. u and v are assigned to different warps

    On Hopper/Blackwell, this uses mbarrier-based synchronization.

    Mutates the PipelinedInstruction objects held by ``code`` in place
    (sets ``needs_barrier_before`` and ``barrier_id``); returns None.
    """
    wa = result.warp_assignment

    # Build a map from (name, iteration) to instruction in the code
    instr_map = {
        (instr.name, instr.iteration): instr for instr in code.all_instructions()
    }

    barrier_counter = 0
    for edge in graph.E:
        if wa.warp_of(edge.src) == wa.warp_of(edge.dst):
            continue  # same warp: no barrier for this edge

        # Cross-warp dependency -> needs barrier. The consumer instance fed by
        # producer copy i lives at iteration i + iteration_delay; instances
        # that were not emitted (outside the schedule) are skipped.
        for i in range(result.num_copies):
            target = instr_map.get((edge.dst, i + edge.iteration_delay))
            if target is not None:
                target.needs_barrier_before = True
                # If several cross-warp edges reach the same instance, the
                # last edge visited determines the recorded barrier id.
                target.barrier_id = f"bar_{barrier_counter}"
        # One barrier id per cross-warp edge, shared by all its copies.
        barrier_counter += 1


def _barrier_wait_prefix(instr: PipelinedInstruction) -> str:
    """Return the newline-terminated pseudocode wait line for *instr*, or ""."""
    if instr.needs_barrier_before:
        return f" mbarrier.wait({instr.barrier_id});\n"
    return ""


def generate_pseudocode(
    graph: DependenceGraph,
    result: JointSWPWSResult,
) -> str:
    """Generate human-readable pseudocode from the Twill solution.

    Returns a string of annotated pseudocode showing the pipelined
    schedule with warp assignments and barriers.

    Args:
        graph: The dependence graph
        result: The joint SWP+WS result from Twill

    Returns:
        The pseudocode as a single newline-joined string.
    """
    code = generate_pipelined_code(graph, result)
    wa = result.warp_assignment
    I = result.I

    lines = [
        "// Twill-generated pipelined schedule",
        f"// Initiation Interval I = {I}",
        f"// Schedule Length L = {result.length}",
        f"// Overlapping copies = {result.num_copies}",
        "//",
    ]

    # Warp assignment summary
    lines.append("// Warp Assignment:")
    for v in graph.V:
        warp = wa.warp_of(v.name)
        warp_label = wa.warp_names.get(warp, f"warp_{warp}")
        var_lat = " [variable-latency]" if v.variable_latency else ""
        lines.append(f"// {v.name} -> {warp_label}{var_lat}")
    lines.append("")

    # Prologue
    if code.prologue:
        lines.append("// ---- PROLOGUE ----")
        for instr in code.prologue:
            # The barrier wait is folded into the same entry as the
            # instruction (as in the steady-state and epilogue sections) so
            # that instructions without a barrier do not emit a blank line.
            lines.append(
                f"{_barrier_wait_prefix(instr)}"
                f"/* cycle {instr.cycle}, warp {instr.warp} */ "
                f"{instr.name}_{instr.iteration} = {instr.name}(...);"
            )
        lines.append("")

    # Steady state
    lines.append("// ---- STEADY STATE (for i in range(N)): ----")
    for instr in code.steady_state:
        lines.append(
            f"{_barrier_wait_prefix(instr)}"
            f"/* cycle {instr.cycle}, warp {instr.warp} */ "
            f"{instr.name} = {instr.name}(...); // iter offset={instr.iteration}"
        )
    lines.append("")

    # Epilogue
    if code.epilogue:
        lines.append("// ---- EPILOGUE ----")
        for instr in code.epilogue:
            lines.append(
                f"{_barrier_wait_prefix(instr)}"
                f"/* cycle {instr.cycle}, warp {instr.warp} */ "
                f"{instr.name}_{instr.iteration} = {instr.name}(...);"
            )

    return "\n".join(lines)


def generate_cuda_skeleton(
    graph: DependenceGraph,
    result: JointSWPWSResult,
) -> str:
    """Generate a CUDA C++ skeleton from the Twill solution.

    This produces a template that shows the warp-specialized structure
    with the correct barriers and pipeline stages.

    Args:
        graph: The dependence graph (its ``machine`` supplies the target
            name and the number of warps)
        result: The joint SWP+WS result from Twill

    Returns:
        The skeleton source as a single newline-joined string.
    """
    wa = result.warp_assignment
    machine = graph.machine
    code = generate_pipelined_code(graph, result)

    lines = []
    lines.append("// CUDA C++ skeleton generated by Twill")
    lines.append(f"// Target: {machine.name}")
    lines.append(f"// I={result.I}, L={result.length}, copies={result.num_copies}")
    lines.append("")
    # NOTE(review): both #include lines are emitted without a header name, so
    # the generated skeleton is not compilable as-is. The intended headers
    # cannot be determined from this file -- confirm and fill them in.
    lines.append("#include ")
    lines.append("#include ")
    lines.append("")
    lines.append("__global__ void twill_kernel(...) {")
    lines.append(" const int warp_id = threadIdx.x / 32;")
    lines.append(" ")

    # Group by warp, keeping only warps that have at least one instruction.
    warp_groups = {}
    for w in range(machine.num_warps):
        instrs = wa.instructions_on_warp(w)
        if instrs:
            warp_groups[w] = instrs

    for w, instrs in warp_groups.items():
        warp_label = wa.warp_names.get(w, f"warp {w}")
        lines.append(f" if (warp_id == {w}) {{ // {warp_label}")
        lines.append(f" // Operations: {instrs}")

        # Show prologue operations for this warp
        warp_prologue = [i for i in code.prologue if i.warp == w]
        if warp_prologue:
            lines.append(" // --- Prologue ---")
            for instr in warp_prologue:
                if instr.needs_barrier_before:
                    lines.append(f" mbarrier::arrive_and_wait({instr.barrier_id});")
                lines.append(f" {instr.name}(...); // iter {instr.iteration}")

        # Show steady-state loop for this warp
        warp_steady = [i for i in code.steady_state if i.warp == w]
        if warp_steady:
            lines.append(" // --- Steady State ---")
            lines.append(" for (int i = 0; i < N; i++) {")
            for instr in warp_steady:
                if instr.needs_barrier_before:
                    lines.append(f" mbarrier::arrive_and_wait({instr.barrier_id});")
                lines.append(f" {instr.name}(...);")
            lines.append(" }")

        # Show epilogue for this warp
        warp_epilogue = [i for i in code.epilogue if i.warp == w]
        if warp_epilogue:
            lines.append(" // --- Epilogue ---")
            for instr in warp_epilogue:
                if instr.needs_barrier_before:
                    lines.append(f" mbarrier::arrive_and_wait({instr.barrier_id});")
                lines.append(f" {instr.name}(...); // iter {instr.iteration}")

        # Close this warp's `if` block and leave a spacer line.
        lines.append(" }")
        lines.append(" ")

    lines.append("}")
    return "\n".join(lines)