"""
Kernel descriptions: Pre-built dependence graphs for common GPU kernels.

These correspond to the kernels evaluated in Section 6 of the paper.
"""
|
|
| import numpy as np |
| from twill.graph import ( |
| DependenceGraph, Instruction, DependenceEdge, MachineDescription, |
| hopper_machine, blackwell_machine, make_rrt, |
| ) |
|
|
|
|
def flash_attention_forward_simplified(
    machine: "MachineDescription | None" = None,
    gemm_cycles: int = 1,
    exp_cycles: int = 1,
) -> DependenceGraph:
    """Simplified Flash Attention forward pass (Figure 1 of the paper).

    The simplified loop body:
        S = gemm(Q, K[i])   # uses TC
        P = exp(S)          # uses EXP
        O += gemm(P, V[i])  # uses TC

    This is the running example from Section 3.
    On Hopper, GEMM and EXP on a tile have roughly the same cost.

    Args:
        machine: Target machine description; defaults to Hopper when None.
        gemm_cycles: Normalized cycle count for a GEMM tile.
        exp_cycles: Normalized cycle count for an EXP tile.

    Returns:
        DependenceGraph ready for Twill.
    """
    if machine is None:
        machine = hopper_machine()

    graph = DependenceGraph(machine)
    num_fus = machine.num_functional_units
    tc_idx = machine.fu_index("TC")
    exp_idx = machine.fu_index("EXP")

    # S = gemm(Q, K[i]): occupies the tensor core for gemm_cycles.
    rrt_S = make_rrt(gemm_cycles, {tc_idx: [1] * gemm_cycles}, num_fus)
    graph.add_instruction(Instruction("S", rrt_S))

    # P = exp(S): occupies the EXP unit for exp_cycles.
    rrt_P = make_rrt(exp_cycles, {exp_idx: [1] * exp_cycles}, num_fus)
    graph.add_instruction(Instruction("P", rrt_P))

    # O += gemm(P, V[i]): second tensor-core op, accumulating into O.
    rrt_O = make_rrt(gemm_cycles, {tc_idx: [1] * gemm_cycles}, num_fus)
    graph.add_instruction(Instruction("O", rrt_O))

    # Intra-iteration chain: S -> P -> O, each delayed by its producer's cost.
    graph.add_edge(DependenceEdge("S", "P", delay=gemm_cycles, iteration_delay=0))
    graph.add_edge(DependenceEdge("P", "O", delay=exp_cycles, iteration_delay=0))
    # Loop-carried accumulation of O (iteration_delay=1).
    graph.add_edge(DependenceEdge("O", "O", delay=gemm_cycles, iteration_delay=1))

    return graph
|
|
|
|
def flash_attention_forward_hopper(
    tma_cycles: int = 1,
    gemm_cycles: int = 2,
    exp_cycles: int = 2,
) -> DependenceGraph:
    """Full FMHA forward pass on Hopper (Section 6.2.1).

    Includes TMA loads for K and V tiles:
        K_load = tma_load(K[i])   # uses TMA (variable latency)
        V_load = tma_load(V[i])   # uses TMA (variable latency)
        S = wgmma(Q, K_load)      # uses TC
        P = exp(S)                # uses EXP
        O += wgmma(P, V_load)     # uses TC

    FA3 discovered: SWP extracts S=gemm(Q,K[0]) into prologue,
    and ping-pong scheduling alternates EXP and TC across warp groups.
    """
    machine = hopper_machine()
    graph = DependenceGraph(machine)
    n_units = machine.num_functional_units
    tc = machine.fu_index("TC")
    exp_unit = machine.fu_index("EXP")
    tma = machine.fu_index("TMA")

    def unit_rrt(cycles, fu):
        # One-unit-per-cycle reservation table on a single functional unit.
        return make_rrt(cycles, {fu: [1] * cycles}, n_units)

    # Async TMA loads of the K and V tiles: variable latency, streamed,
    # each staging a 128x128 fp16 tile (128*128*2 bytes) in shared memory.
    for load_name in ("K_load", "V_load"):
        graph.add_instruction(Instruction(
            load_name, unit_rrt(tma_cycles, tma),
            variable_latency=True, streaming=True,
            memory_footprint={"SMEM": 128 * 128 * 2},
        ))

    # Compute chain: S = wgmma(Q, K), P = exp(S), O += wgmma(P, V).
    graph.add_instruction(Instruction("S", unit_rrt(gemm_cycles, tc)))
    graph.add_instruction(Instruction("P", unit_rrt(exp_cycles, exp_unit)))
    graph.add_instruction(Instruction("O", unit_rrt(gemm_cycles, tc)))

    # Dependences: loads feed their consumers, S->P->O within an iteration,
    # and O carries its accumulator across iterations (delta=1).
    for src, dst, delay, delta in (
        ("K_load", "S", tma_cycles, 0),
        ("V_load", "O", tma_cycles, 0),
        ("S", "P", gemm_cycles, 0),
        ("P", "O", exp_cycles, 0),
        ("O", "O", gemm_cycles, 1),
    ):
        graph.add_edge(DependenceEdge(src, dst, delay=delay, iteration_delay=delta))

    return graph
|
|
|
|
def flash_attention_forward_blackwell(
    tma_cycles: int = 1,
    gemm_cycles: int = 1,
    exp_cycles: int = 2,
    tmem_cycles: int = 1,
) -> DependenceGraph:
    """FMHA forward pass on Blackwell (Section 6.2.2).

    Blackwell differences:
    - TC throughput 2x Hopper -> GEMM takes fewer relative cycles
    - Tensor Memory (TMEM) tier for TC accumulators
    - Explicit register <-> TMEM transfers needed
    - EXP/softmax becomes the bottleneck (doesn't scale as fast)

    FA4 strategy (rediscovered by Twill):
    - TMA loads on variable-latency warp (producer)
    - TC GEMMs on compute warps
    - Specific cross-warp communication for TMEM
    """
    machine = blackwell_machine()
    graph = DependenceGraph(machine)
    n_units = machine.num_functional_units
    tc = machine.fu_index("TC")
    exp_unit = machine.fu_index("EXP")
    tma = machine.fu_index("TMA")
    tmem = machine.fu_index("TMEM")

    def unit_rrt(cycles, fu):
        # One-unit-per-cycle reservation table on a single functional unit.
        return make_rrt(cycles, {fu: [1] * cycles}, n_units)

    # Async TMA loads of the K and V tiles (128x128 fp16 -> SMEM).
    for load_name in ("K_load", "V_load"):
        graph.add_instruction(Instruction(
            load_name, unit_rrt(tma_cycles, tma),
            variable_latency=True, streaming=True,
            memory_footprint={"SMEM": 128 * 128 * 2},
        ))

    # S = gemm(Q, K): accumulator lives in TMEM (128x128 fp32).
    graph.add_instruction(Instruction(
        "S", unit_rrt(gemm_cycles, tc),
        memory_footprint={"TMEM": 128 * 128 * 4},
    ))
    # Explicit TMEM -> register read of S before the softmax.
    graph.add_instruction(Instruction("S_read", unit_rrt(tmem_cycles, tmem)))
    # P = exp(S): the EXP unit is the Blackwell bottleneck.
    graph.add_instruction(Instruction("P", unit_rrt(exp_cycles, exp_unit)))
    # Explicit register -> TMEM write of P before the second GEMM.
    graph.add_instruction(Instruction("P_write", unit_rrt(tmem_cycles, tmem)))
    # O += gemm(P, V).
    graph.add_instruction(Instruction("O", unit_rrt(gemm_cycles, tc)))

    # Dependence chain, including the TMEM round trip; O is loop-carried.
    for src, dst, delay, delta in (
        ("K_load", "S", tma_cycles, 0),
        ("S", "S_read", gemm_cycles, 0),
        ("S_read", "P", tmem_cycles, 0),
        ("P", "P_write", exp_cycles, 0),
        ("P_write", "O", tmem_cycles, 0),
        ("V_load", "O", tma_cycles, 0),
        ("O", "O", gemm_cycles, 1),
    ):
        graph.add_edge(DependenceEdge(src, dst, delay=delay, iteration_delay=delta))

    return graph
|
|
|
|
def simple_gemm_pipeline(
    machine: "MachineDescription | None" = None,
    load_cycles: int = 1,
    compute_cycles: int = 2,
) -> DependenceGraph:
    """Simple GEMM with load-compute overlap.

    The simplest pipelining case:
        A_load = tma_load(A[i])
        B_load = tma_load(B[i])
        C += gemm(A_load, B_load)

    Args:
        machine: Target machine description; defaults to Hopper when None.
        load_cycles: Normalized cycle count for each TMA load.
        compute_cycles: Normalized cycle count for the GEMM.

    Returns:
        DependenceGraph ready for Twill.
    """
    if machine is None:
        machine = hopper_machine()

    graph = DependenceGraph(machine)
    num_fus = machine.num_functional_units
    tc_idx = machine.fu_index("TC")
    tma_idx = machine.fu_index("TMA")

    # Async TMA loads of the A and B tiles (variable latency, streamed).
    rrt_A = make_rrt(load_cycles, {tma_idx: [1] * load_cycles}, num_fus)
    graph.add_instruction(Instruction("A_load", rrt_A, variable_latency=True, streaming=True))

    rrt_B = make_rrt(load_cycles, {tma_idx: [1] * load_cycles}, num_fus)
    graph.add_instruction(Instruction("B_load", rrt_B, variable_latency=True, streaming=True))

    # C += gemm(A, B) on the tensor core.
    rrt_C = make_rrt(compute_cycles, {tc_idx: [1] * compute_cycles}, num_fus)
    graph.add_instruction(Instruction("C", rrt_C))

    # Both loads feed the GEMM; C accumulates across iterations (delta=1).
    graph.add_edge(DependenceEdge("A_load", "C", delay=load_cycles, iteration_delay=0))
    graph.add_edge(DependenceEdge("B_load", "C", delay=load_cycles, iteration_delay=0))
    graph.add_edge(DependenceEdge("C", "C", delay=compute_cycles, iteration_delay=1))

    return graph
|
|
|
|
def custom_kernel(
    machine: MachineDescription,
    instructions: list,
    edges: list,
) -> DependenceGraph:
    """Build a custom kernel dependence graph.

    Args:
        machine: Target machine description.
        instructions: List of dicts with keys:
            - name: str
            - cycles: int
            - fu: str (functional unit name)
            - variable_latency: bool (optional)
            - streaming: bool (optional)
            - memory: dict (optional, memory space -> bytes)
        edges: List of dicts with keys:
            - src: str
            - dst: str
            - delay: int
            - delta: int (iteration delay, default 0)

    Returns:
        DependenceGraph
    """
    graph = DependenceGraph(machine)
    total_units = machine.num_functional_units

    for spec in instructions:
        # Resolve the named functional unit, then build a reservation table
        # that holds that unit for the instruction's full duration.
        unit = machine.fu_index(spec["fu"])
        n_cycles = spec["cycles"]
        graph.add_instruction(Instruction(
            name=spec["name"],
            rrt=make_rrt(n_cycles, {unit: [1] * n_cycles}, total_units),
            variable_latency=spec.get("variable_latency", False),
            streaming=spec.get("streaming", False),
            memory_footprint=spec.get("memory", {}),
        ))

    for spec in edges:
        graph.add_edge(DependenceEdge(
            src=spec["src"],
            dst=spec["dst"],
            delay=spec["delay"],
            iteration_delay=spec.get("delta", 0),
        ))

    return graph
|
|