# Source: twill-swp-ws / twill/kernels.py
# Uploaded by AshenNav with huggingface_hub (commit 36457f1).
"""
Kernel descriptions: Pre-built dependence graphs for common GPU kernels.
These correspond to the kernels evaluated in Section 6 of the paper.
"""
from typing import Optional

import numpy as np

from twill.graph import (
    DependenceGraph, Instruction, DependenceEdge, MachineDescription,
    hopper_machine, blackwell_machine, make_rrt,
)
def flash_attention_forward_simplified(
    machine: Optional[MachineDescription] = None,
    gemm_cycles: int = 1,
    exp_cycles: int = 1,
) -> DependenceGraph:
    """Simplified Flash Attention forward pass (Figure 1 of the paper).

    The simplified loop body:
        S = gemm(Q, K[i])    # uses TC
        P = exp(S)           # uses EXP
        O += gemm(P, V[i])   # uses TC

    This is the running example from Section 3. On Hopper, GEMM and EXP on a
    tile have roughly the same cost.

    Args:
        machine: Target machine (default: ``hopper_machine()``).
        gemm_cycles: Normalized cycle count for each GEMM; must be >= 1.
        exp_cycles: Normalized cycle count for EXP; must be >= 1.

    Returns:
        DependenceGraph ready for Twill.

    Raises:
        ValueError: If a cycle count is smaller than 1.
    """
    # A count < 1 would hand make_rrt a zero/negative length and an empty
    # usage list ([1] * 0 == []), i.e. an instruction that occupies no unit.
    for label, cycles in (("gemm_cycles", gemm_cycles), ("exp_cycles", exp_cycles)):
        if cycles < 1:
            raise ValueError(f"{label} must be >= 1, got {cycles!r}")

    if machine is None:
        machine = hopper_machine()
    graph = DependenceGraph(machine)
    num_fus = machine.num_functional_units
    tc_idx = machine.fu_index("TC")
    exp_idx = machine.fu_index("EXP")

    # S = gemm(Q, K[i]) - uses TC for gemm_cycles
    rrt_S = make_rrt(gemm_cycles, {tc_idx: [1] * gemm_cycles}, num_fus)
    graph.add_instruction(Instruction("S", rrt_S))
    # P = exp(S) - uses EXP for exp_cycles
    rrt_P = make_rrt(exp_cycles, {exp_idx: [1] * exp_cycles}, num_fus)
    graph.add_instruction(Instruction("P", rrt_P))
    # O = gemm(P, V[i]) - uses TC for gemm_cycles
    rrt_O = make_rrt(gemm_cycles, {tc_idx: [1] * gemm_cycles}, num_fus)
    graph.add_instruction(Instruction("O", rrt_O))

    # Dependencies (Figure 1c):
    # S -> P (d = gemm_cycles, δ = 0) - P must wait for GEMM to finish
    graph.add_edge(DependenceEdge("S", "P", delay=gemm_cycles, iteration_delay=0))
    # P -> O (d = exp_cycles, δ = 0) - O must wait for EXP to finish
    graph.add_edge(DependenceEdge("P", "O", delay=exp_cycles, iteration_delay=0))
    # O -> O (d = gemm_cycles, δ = 1) - loop-carried: O accumulates across iterations
    graph.add_edge(DependenceEdge("O", "O", delay=gemm_cycles, iteration_delay=1))
    return graph
def flash_attention_forward_hopper(
    tma_cycles: int = 1,
    gemm_cycles: int = 2,
    exp_cycles: int = 2,
) -> DependenceGraph:
    """Build the full FMHA forward-pass dependence graph for Hopper (Section 6.2.1).

    Extends the simplified running example with TMA loads of the K and V tiles:
        K_load = tma_load(K[i])   # TMA (variable latency)
        V_load = tma_load(V[i])   # TMA (variable latency)
        S = wgmma(Q, K_load)      # TC
        P = exp(S)                # EXP
        O += wgmma(P, V_load)     # TC

    FA3 discovered: SWP extracts S=gemm(Q,K[0]) into prologue, and ping-pong
    scheduling alternates EXP and TC across warp groups.

    Args:
        tma_cycles: Normalized cycle count of each TMA load.
        gemm_cycles: Normalized cycle count of each GEMM.
        exp_cycles: Normalized cycle count of the EXP stage.

    Returns:
        DependenceGraph built for ``hopper_machine()``.
    """
    machine = hopper_machine()
    graph = DependenceGraph(machine)
    fu_count = machine.num_functional_units
    tc, exp_fu, tma = (machine.fu_index(unit) for unit in ("TC", "EXP", "TMA"))

    def busy(fu, cycles):
        """Reservation table that keeps unit ``fu`` busy on all ``cycles`` cycles."""
        return make_rrt(cycles, {fu: [1] * cycles}, fu_count)

    # Producer side: streaming, variable-latency TMA loads, each staging one
    # 128x128 FP16 tile (2 bytes per element) in SMEM.
    for tile in ("K_load", "V_load"):
        graph.add_instruction(Instruction(
            tile,
            busy(tma, tma_cycles),
            variable_latency=True,
            streaming=True,
            memory_footprint={"SMEM": 128 * 128 * 2},
        ))

    graph.add_instruction(Instruction("S", busy(tc, gemm_cycles)))   # S = wgmma(Q, K_load)
    graph.add_instruction(Instruction("P", busy(exp_fu, exp_cycles)))  # P = exp(S), softmax rescaling
    graph.add_instruction(Instruction("O", busy(tc, gemm_cycles)))   # O += wgmma(P, V_load)

    # (src, dst, delay, iteration delay); only O -> O is loop-carried.
    edges = (
        ("K_load", "S", tma_cycles, 0),   # K tile must land before the first GEMM
        ("V_load", "O", tma_cycles, 0),   # V tile must land before the second GEMM
        ("S", "P", gemm_cycles, 0),       # softmax consumes the GEMM result
        ("P", "O", exp_cycles, 0),        # second GEMM consumes the softmax result
        ("O", "O", gemm_cycles, 1),       # accumulation across iterations
    )
    for src, dst, delay, delta in edges:
        graph.add_edge(DependenceEdge(src, dst, delay=delay, iteration_delay=delta))
    return graph
def flash_attention_forward_blackwell(
    tma_cycles: int = 1,
    gemm_cycles: int = 1,  # TC 2x faster on Blackwell
    exp_cycles: int = 2,  # EXP unchanged
    tmem_cycles: int = 1,
) -> DependenceGraph:
    """Build the FMHA forward-pass dependence graph for Blackwell (Section 6.2.2).

    Blackwell differences:
      - TC throughput 2x Hopper -> GEMM takes fewer relative cycles
      - Tensor Memory (TMEM) tier for TC accumulators
      - Explicit register <-> TMEM transfers needed
      - EXP/softmax becomes the bottleneck (doesn't scale as fast)

    FA4 strategy (rediscovered by Twill):
      - TMA loads on variable-latency warp (producer)
      - TC GEMMs on compute warps
      - Specific cross-warp communication for TMEM

    Args:
        tma_cycles: Normalized cycle count of each TMA load.
        gemm_cycles: Normalized cycle count of each TC GEMM.
        exp_cycles: Normalized cycle count of the EXP stage.
        tmem_cycles: Normalized cycle count of each TMEM <-> register transfer.

    Returns:
        DependenceGraph built for ``blackwell_machine()``.
    """
    machine = blackwell_machine()
    graph = DependenceGraph(machine)
    fu_count = machine.num_functional_units
    tc, exp_fu, tma, tmem = (
        machine.fu_index(unit) for unit in ("TC", "EXP", "TMA", "TMEM")
    )

    def busy(fu, cycles):
        """Reservation table that keeps unit ``fu`` busy on all ``cycles`` cycles."""
        return make_rrt(cycles, {fu: [1] * cycles}, fu_count)

    # Producer side: streaming, variable-latency TMA loads, each staging one
    # 128x128 tile (2 bytes per element) in SMEM.
    for tile in ("K_load", "V_load"):
        graph.add_instruction(Instruction(
            tile,
            busy(tma, tma_cycles),
            variable_latency=True,
            streaming=True,
            memory_footprint={"SMEM": 128 * 128 * 2},
        ))

    # S = mma(Q, K_load) on TC; the result (128*128*4 bytes) is held in TMEM.
    graph.add_instruction(Instruction(
        "S", busy(tc, gemm_cycles), memory_footprint={"TMEM": 128 * 128 * 4},
    ))
    graph.add_instruction(Instruction("S_read", busy(tmem, tmem_cycles)))   # TMEM -> registers
    graph.add_instruction(Instruction("P", busy(exp_fu, exp_cycles)))       # softmax exp in registers
    graph.add_instruction(Instruction("P_write", busy(tmem, tmem_cycles)))  # registers -> TMEM
    # O += mma(P_write, V_load) on TC, accumulating in TMEM.
    # NOTE(review): unlike S, neither O nor P_write declares a TMEM
    # memory_footprint although both target TMEM -- confirm this is intended.
    graph.add_instruction(Instruction("O", busy(tc, gemm_cycles)))

    # (src, dst, delay, iteration delay); only O -> O is loop-carried.
    edges = (
        ("K_load", "S", tma_cycles, 0),
        ("S", "S_read", gemm_cycles, 0),
        ("S_read", "P", tmem_cycles, 0),
        ("P", "P_write", exp_cycles, 0),
        ("P_write", "O", tmem_cycles, 0),
        ("V_load", "O", tma_cycles, 0),
        ("O", "O", gemm_cycles, 1),
    )
    for src, dst, delay, delta in edges:
        graph.add_edge(DependenceEdge(src, dst, delay=delay, iteration_delay=delta))
    return graph
def simple_gemm_pipeline(
    machine: Optional[MachineDescription] = None,
    load_cycles: int = 1,
    compute_cycles: int = 2,
) -> DependenceGraph:
    """Simple GEMM with load-compute overlap.

    The simplest pipelining case:
        A_load = tma_load(A[i])
        B_load = tma_load(B[i])
        C += gemm(A_load, B_load)

    Args:
        machine: Target machine (default: ``hopper_machine()``).
        load_cycles: Normalized cycle count of each TMA load; must be >= 1.
        compute_cycles: Normalized cycle count of the GEMM; must be >= 1.

    Returns:
        DependenceGraph ready for Twill.

    Raises:
        ValueError: If a cycle count is smaller than 1.
    """
    # A count < 1 would hand make_rrt a zero/negative length and an empty
    # usage list ([1] * 0 == []), i.e. an instruction that occupies no unit.
    for label, cycles in (("load_cycles", load_cycles), ("compute_cycles", compute_cycles)):
        if cycles < 1:
            raise ValueError(f"{label} must be >= 1, got {cycles!r}")

    if machine is None:
        machine = hopper_machine()
    graph = DependenceGraph(machine)
    num_fus = machine.num_functional_units
    tc_idx = machine.fu_index("TC")
    tma_idx = machine.fu_index("TMA")

    # Streaming, variable-latency operand loads on TMA.
    rrt_A = make_rrt(load_cycles, {tma_idx: [1] * load_cycles}, num_fus)
    graph.add_instruction(Instruction("A_load", rrt_A, variable_latency=True, streaming=True))
    rrt_B = make_rrt(load_cycles, {tma_idx: [1] * load_cycles}, num_fus)
    graph.add_instruction(Instruction("B_load", rrt_B, variable_latency=True, streaming=True))
    # C += gemm(A_load, B_load) on TC.
    rrt_C = make_rrt(compute_cycles, {tc_idx: [1] * compute_cycles}, num_fus)
    graph.add_instruction(Instruction("C", rrt_C))

    graph.add_edge(DependenceEdge("A_load", "C", delay=load_cycles, iteration_delay=0))
    graph.add_edge(DependenceEdge("B_load", "C", delay=load_cycles, iteration_delay=0))
    # Loop-carried: C accumulates across iterations.
    graph.add_edge(DependenceEdge("C", "C", delay=compute_cycles, iteration_delay=1))
    return graph
def custom_kernel(
    machine: MachineDescription,
    instructions: list,
    edges: list,
) -> DependenceGraph:
    """Build a custom kernel dependence graph from plain descriptors.

    Args:
        machine: Target machine description.
        instructions: Iterable of dicts with keys:
            - name: str
            - cycles: int (must be >= 1)
            - fu: str (functional unit name, resolved via ``machine.fu_index``)
            - variable_latency: bool (optional, default False)
            - streaming: bool (optional, default False)
            - memory: dict (optional, memory space -> bytes; copied, so later
              changes to the caller's dict do not affect the graph)
        edges: Iterable of dicts with keys:
            - src: str
            - dst: str
            - delay: int
            - delta: int (iteration delay, default 0)

    Returns:
        DependenceGraph

    Raises:
        KeyError: If a descriptor lacks a required key (the message names the
            descriptor's position and the missing keys).
        ValueError: If an instruction's ``cycles`` is smaller than 1.
    """
    # Materialize once: the descriptors are walked twice (validate, then build).
    instructions = list(instructions)

    # Validate all instruction descriptors before building anything so the
    # first bad entry is reported with its position.
    for pos, instr_desc in enumerate(instructions):
        missing = [key for key in ("name", "cycles", "fu") if key not in instr_desc]
        if missing:
            raise KeyError(
                f"instruction descriptor {pos} is missing required key(s) {missing}"
            )
        if instr_desc["cycles"] < 1:
            # [1] * cycles would be empty, giving an instruction that uses no unit.
            raise ValueError(
                f"instruction {instr_desc['name']!r}: cycles must be >= 1, "
                f"got {instr_desc['cycles']!r}"
            )

    graph = DependenceGraph(machine)
    num_fus = machine.num_functional_units
    for instr_desc in instructions:
        fu_idx = machine.fu_index(instr_desc["fu"])
        cycles = instr_desc["cycles"]
        rrt = make_rrt(cycles, {fu_idx: [1] * cycles}, num_fus)
        memory = instr_desc.get("memory")
        graph.add_instruction(Instruction(
            name=instr_desc["name"],
            rrt=rrt,
            variable_latency=instr_desc.get("variable_latency", False),
            streaming=instr_desc.get("streaming", False),
            # Copy so the graph does not alias (and track edits to) caller state.
            memory_footprint=dict(memory) if memory else {},
        ))

    for pos, edge_desc in enumerate(edges):
        missing = [key for key in ("src", "dst", "delay") if key not in edge_desc]
        if missing:
            raise KeyError(f"edge descriptor {pos} is missing required key(s) {missing}")
        graph.add_edge(DependenceEdge(
            src=edge_desc["src"],
            dst=edge_desc["dst"],
            delay=edge_desc["delay"],
            iteration_delay=edge_desc.get("delta", 0),
        ))
    return graph