""" Kernel descriptions: Pre-built dependence graphs for common GPU kernels. These correspond to the kernels evaluated in Section 6 of the paper. """ import numpy as np from twill.graph import ( DependenceGraph, Instruction, DependenceEdge, MachineDescription, hopper_machine, blackwell_machine, make_rrt, ) def flash_attention_forward_simplified( machine: MachineDescription = None, gemm_cycles: int = 1, exp_cycles: int = 1, ) -> DependenceGraph: """Simplified Flash Attention forward pass (Figure 1 of the paper). The simplified loop body: S = gemm(Q, K[i]) # uses TC P = exp(S) # uses EXP O += gemm(P, V[i]) # uses TC This is the running example from Section 3. On Hopper, GEMM and EXP on a tile have roughly the same cost. Args: machine: Target machine (default: Hopper) gemm_cycles: Normalized cycle count for GEMM exp_cycles: Normalized cycle count for EXP Returns: DependenceGraph ready for Twill """ if machine is None: machine = hopper_machine() graph = DependenceGraph(machine) num_fus = machine.num_functional_units tc_idx = machine.fu_index("TC") exp_idx = machine.fu_index("EXP") # S = gemm(Q, K[i]) - uses TC for gemm_cycles rrt_S = make_rrt(gemm_cycles, {tc_idx: [1] * gemm_cycles}, num_fus) graph.add_instruction(Instruction("S", rrt_S)) # P = exp(S) - uses EXP for exp_cycles rrt_P = make_rrt(exp_cycles, {exp_idx: [1] * exp_cycles}, num_fus) graph.add_instruction(Instruction("P", rrt_P)) # O = gemm(P, V[i]) - uses TC for gemm_cycles rrt_O = make_rrt(gemm_cycles, {tc_idx: [1] * gemm_cycles}, num_fus) graph.add_instruction(Instruction("O", rrt_O)) # Dependencies (Figure 1c): # S -> P (d = gemm_cycles, δ = 0) - P must wait for GEMM to finish graph.add_edge(DependenceEdge("S", "P", delay=gemm_cycles, iteration_delay=0)) # P -> O (d = exp_cycles, δ = 0) - O must wait for EXP to finish graph.add_edge(DependenceEdge("P", "O", delay=exp_cycles, iteration_delay=0)) # O -> O (d = gemm_cycles, δ = 1) - loop-carried: O accumulates across iterations graph.add_edge(DependenceEdge("O", "O", delay=gemm_cycles, iteration_delay=1)) return graph def flash_attention_forward_hopper( tma_cycles: int = 1, gemm_cycles: int = 2, exp_cycles: int = 2, ) -> DependenceGraph: """Full FMHA forward pass on Hopper (Section 6.2.1). Includes TMA loads for K and V tiles: K_load = tma_load(K[i]) # uses TMA (variable latency) V_load = tma_load(V[i]) # uses TMA (variable latency) S = wgmma(Q, K_load) # uses TC P = exp(S) # uses EXP O += wgmma(P, V_load) # uses TC FA3 discovered: SWP extracts S=gemm(Q,K[0]) into prologue, and ping-pong scheduling alternates EXP and TC across warp groups. 
""" machine = hopper_machine() graph = DependenceGraph(machine) num_fus = machine.num_functional_units tc_idx = machine.fu_index("TC") exp_idx = machine.fu_index("EXP") tma_idx = machine.fu_index("TMA") # TMA loads (variable latency, streaming) rrt_K_load = make_rrt(tma_cycles, {tma_idx: [1] * tma_cycles}, num_fus) graph.add_instruction(Instruction( "K_load", rrt_K_load, variable_latency=True, streaming=True, memory_footprint={"SMEM": 128 * 128 * 2} # FP16 tile )) rrt_V_load = make_rrt(tma_cycles, {tma_idx: [1] * tma_cycles}, num_fus) graph.add_instruction(Instruction( "V_load", rrt_V_load, variable_latency=True, streaming=True, memory_footprint={"SMEM": 128 * 128 * 2} )) # S = wgmma(Q, K_load) rrt_S = make_rrt(gemm_cycles, {tc_idx: [1] * gemm_cycles}, num_fus) graph.add_instruction(Instruction("S", rrt_S)) # P = exp(S) - softmax rescaling rrt_P = make_rrt(exp_cycles, {exp_idx: [1] * exp_cycles}, num_fus) graph.add_instruction(Instruction("P", rrt_P)) # O += wgmma(P, V_load) - accumulate rrt_O = make_rrt(gemm_cycles, {tc_idx: [1] * gemm_cycles}, num_fus) graph.add_instruction(Instruction("O", rrt_O)) # Dependencies: # K_load -> S (TMA must complete before GEMM can consume K) graph.add_edge(DependenceEdge("K_load", "S", delay=tma_cycles, iteration_delay=0)) # V_load -> O (TMA must complete before GEMM can consume V) graph.add_edge(DependenceEdge("V_load", "O", delay=tma_cycles, iteration_delay=0)) # S -> P (GEMM result needed for softmax) graph.add_edge(DependenceEdge("S", "P", delay=gemm_cycles, iteration_delay=0)) # P -> O (softmax result needed for second GEMM) graph.add_edge(DependenceEdge("P", "O", delay=exp_cycles, iteration_delay=0)) # O -> O (loop carried: accumulation) graph.add_edge(DependenceEdge("O", "O", delay=gemm_cycles, iteration_delay=1)) return graph def flash_attention_forward_blackwell( tma_cycles: int = 1, gemm_cycles: int = 1, # TC 2x faster on Blackwell exp_cycles: int = 2, # EXP unchanged tmem_cycles: int = 1, ) -> DependenceGraph: """FMHA forward pass on Blackwell (Section 6.2.2). 
def flash_attention_forward_blackwell(
    tma_cycles: int = 1,
    gemm_cycles: int = 1,  # TC 2x faster on Blackwell
    exp_cycles: int = 2,   # EXP unchanged
    tmem_cycles: int = 1,
) -> DependenceGraph:
    """FMHA forward pass on Blackwell (Section 6.2.2).

    Blackwell differences:
    - TC throughput 2x Hopper -> GEMM takes fewer relative cycles
    - Tensor Memory (TMEM) tier for TC accumulators
    - Explicit register <-> TMEM transfers needed
    - EXP/softmax becomes the bottleneck (doesn't scale as fast)

    FA4 strategy (rediscovered by Twill):
    - TMA loads on a variable-latency warp (producer)
    - TC GEMMs on compute warps
    - Specific cross-warp communication for TMEM
    """
    machine = blackwell_machine()
    graph = DependenceGraph(machine)
    num_fus = machine.num_functional_units
    tc_idx = machine.fu_index("TC")
    exp_idx = machine.fu_index("EXP")
    tma_idx = machine.fu_index("TMA")
    tmem_idx = machine.fu_index("TMEM")

    # TMA loads (variable latency, streaming)
    rrt_K_load = make_rrt(tma_cycles, {tma_idx: [1] * tma_cycles}, num_fus)
    graph.add_instruction(Instruction(
        "K_load", rrt_K_load,
        variable_latency=True,
        streaming=True,
        memory_footprint={"SMEM": 128 * 128 * 2},
    ))
    rrt_V_load = make_rrt(tma_cycles, {tma_idx: [1] * tma_cycles}, num_fus)
    graph.add_instruction(Instruction(
        "V_load", rrt_V_load,
        variable_latency=True,
        streaming=True,
        memory_footprint={"SMEM": 128 * 128 * 2},
    ))

    # S = wgmma(Q, K_load) - output goes to TMEM
    rrt_S = make_rrt(gemm_cycles, {tc_idx: [1] * gemm_cycles}, num_fus)
    graph.add_instruction(Instruction(
        "S", rrt_S,
        memory_footprint={"TMEM": 128 * 128 * 4},  # 128x128 FP32 accumulator
    ))

    # S_read: TMEM -> register transfer
    rrt_S_read = make_rrt(tmem_cycles, {tmem_idx: [1] * tmem_cycles}, num_fus)
    graph.add_instruction(Instruction("S_read", rrt_S_read))

    # P = exp(S_read) - softmax in registers
    rrt_P = make_rrt(exp_cycles, {exp_idx: [1] * exp_cycles}, num_fus)
    graph.add_instruction(Instruction("P", rrt_P))

    # P_write: register -> TMEM for the second GEMM
    rrt_P_write = make_rrt(tmem_cycles, {tmem_idx: [1] * tmem_cycles}, num_fus)
    graph.add_instruction(Instruction("P_write", rrt_P_write))

    # O += wgmma(P_write, V_load) - accumulate in TMEM
    rrt_O = make_rrt(gemm_cycles, {tc_idx: [1] * gemm_cycles}, num_fus)
    graph.add_instruction(Instruction("O", rrt_O))

    # Dependencies:
    graph.add_edge(DependenceEdge("K_load", "S", delay=tma_cycles, iteration_delay=0))
    graph.add_edge(DependenceEdge("S", "S_read", delay=gemm_cycles, iteration_delay=0))
    graph.add_edge(DependenceEdge("S_read", "P", delay=tmem_cycles, iteration_delay=0))
    graph.add_edge(DependenceEdge("P", "P_write", delay=exp_cycles, iteration_delay=0))
    graph.add_edge(DependenceEdge("P_write", "O", delay=tmem_cycles, iteration_delay=0))
    graph.add_edge(DependenceEdge("V_load", "O", delay=tma_cycles, iteration_delay=0))
    graph.add_edge(DependenceEdge("O", "O", delay=gemm_cycles, iteration_delay=1))

    return graph
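
# ---------------------------------------------------------------------------
# The same hand-computed sketch for Blackwell (illustrative only). At the
# defaults the TC entry halves relative to Hopper (2 vs 4 busy cycles) while
# EXP stays at 2, so softmax no longer hides behind the GEMMs (the shift the
# docstring describes), and the TMEM transfers add a fourth unit to balance.
# ---------------------------------------------------------------------------
def _blackwell_fu_busy_cycles(
    tma_cycles: int = 1,
    gemm_cycles: int = 1,
    exp_cycles: int = 2,
    tmem_cycles: int = 1,
) -> dict:
    """Mirrors the RRTs built in flash_attention_forward_blackwell."""
    return {
        "TMA": 2 * tma_cycles,    # K_load + V_load
        "TC": 2 * gemm_cycles,    # S + O
        "EXP": exp_cycles,        # P
        "TMEM": 2 * tmem_cycles,  # S_read + P_write
    }
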
def simple_gemm_pipeline(
    machine: Optional[MachineDescription] = None,
    load_cycles: int = 1,
    compute_cycles: int = 2,
) -> DependenceGraph:
    """Simple GEMM with load-compute overlap.

    The simplest pipelining case:

        A_load = tma_load(A[i])
        B_load = tma_load(B[i])
        C += gemm(A_load, B_load)
    """
    if machine is None:
        machine = hopper_machine()

    graph = DependenceGraph(machine)
    num_fus = machine.num_functional_units
    tc_idx = machine.fu_index("TC")
    tma_idx = machine.fu_index("TMA")

    # Streaming TMA loads for the A and B tiles
    rrt_A = make_rrt(load_cycles, {tma_idx: [1] * load_cycles}, num_fus)
    graph.add_instruction(Instruction("A_load", rrt_A, variable_latency=True, streaming=True))
    rrt_B = make_rrt(load_cycles, {tma_idx: [1] * load_cycles}, num_fus)
    graph.add_instruction(Instruction("B_load", rrt_B, variable_latency=True, streaming=True))
    # C accumulates on the tensor core
    rrt_C = make_rrt(compute_cycles, {tc_idx: [1] * compute_cycles}, num_fus)
    graph.add_instruction(Instruction("C", rrt_C))

    graph.add_edge(DependenceEdge("A_load", "C", delay=load_cycles, iteration_delay=0))
    graph.add_edge(DependenceEdge("B_load", "C", delay=load_cycles, iteration_delay=0))
    graph.add_edge(DependenceEdge("C", "C", delay=compute_cycles, iteration_delay=1))

    return graph


def custom_kernel(
    machine: MachineDescription,
    instructions: list,
    edges: list,
) -> DependenceGraph:
    """Build a custom kernel dependence graph.

    Args:
        machine: Target machine description
        instructions: List of dicts with keys:
            - name: str
            - cycles: int
            - fu: str (functional unit name)
            - variable_latency: bool (optional)
            - streaming: bool (optional)
            - memory: dict (optional, memory space -> bytes)
        edges: List of dicts with keys:
            - src: str
            - dst: str
            - delay: int
            - delta: int (iteration delay, default 0)

    Returns:
        DependenceGraph
    """
    graph = DependenceGraph(machine)
    num_fus = machine.num_functional_units

    for instr_desc in instructions:
        fu_name = instr_desc["fu"]
        fu_idx = machine.fu_index(fu_name)
        cycles = instr_desc["cycles"]
        rrt = make_rrt(cycles, {fu_idx: [1] * cycles}, num_fus)
        graph.add_instruction(Instruction(
            name=instr_desc["name"],
            rrt=rrt,
            variable_latency=instr_desc.get("variable_latency", False),
            streaming=instr_desc.get("streaming", False),
            memory_footprint=instr_desc.get("memory", {}),
        ))

    for edge_desc in edges:
        graph.add_edge(DependenceEdge(
            src=edge_desc["src"],
            dst=edge_desc["dst"],
            delay=edge_desc["delay"],
            iteration_delay=edge_desc.get("delta", 0),
        ))

    return graph
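
# ---------------------------------------------------------------------------
# Usage sketch for custom_kernel: rebuild the simple GEMM pipeline through the
# dict schema documented above. With these costs the result should match
# simple_gemm_pipeline(load_cycles=1, compute_cycles=2) on the default Hopper
# machine (equivalence by construction; not asserted programmatically here).
# ---------------------------------------------------------------------------
def _example_custom_gemm() -> DependenceGraph:
    machine = hopper_machine()
    return custom_kernel(
        machine,
        instructions=[
            {"name": "A_load", "cycles": 1, "fu": "TMA",
             "variable_latency": True, "streaming": True},
            {"name": "B_load", "cycles": 1, "fu": "TMA",
             "variable_latency": True, "streaming": True},
            {"name": "C", "cycles": 2, "fu": "TC"},
        ],
        edges=[
            {"src": "A_load", "dst": "C", "delay": 1},
            {"src": "B_load", "dst": "C", "delay": 1},
            {"src": "C", "dst": "C", "delay": 2, "delta": 1},  # loop-carried accumulate
        ],
    )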