AshenNav committed on
Commit 9e6c8b9 · verified · 1 Parent(s): 1717684

Upload README.md

Files changed (1):
  1. README.md +103 -128
README.md CHANGED
@@ -1,185 +1,150 @@
- # Twill: Optimal Software Pipelining and Warp Specialization for Tensor Core GPUs
-
- Implementation of **["Optimal Software Pipelining and Warp Specialization for Tensor Core GPUs"](https://arxiv.org/abs/2512.18134)** by Rupanshu Soi et al. (NVIDIA Research, 2024).
-
- ## What is Twill?
-
- Twill is the first system that automatically derives **provably optimal** Software Pipelining (SWP) + Warp Specialization (WS) schedules for tensor core GPU kernels. It formulates the joint optimization as a constraint satisfaction problem solved by off-the-shelf ILP and SMT solvers.
-
- **Key result:** Twill automatically rediscovers the expert-designed schedules of FlashAttention-3 (Hopper) and FlashAttention-4 (Blackwell) — proving these human-designed schedules are optimal.
-
- ## Architecture
-
- The solver has two phases, following Algorithm 1 from the paper:
-
  ```
- Phase 1: ILP Modulo Scheduling (CBC solver via PuLP)
-   → Finds optimal initiation interval I and initial schedule M
-
- Phase 2: SMT Joint SWP + WS (Z3 solver)
-   → Finds optimal schedule M* and warp assignment A*
-   → Encodes constraints from Figures 4, 5, 6 of the paper
-
- Cost Normalization (Section 5.2):
-   → Reduces cycle counts while preserving ratios
-   → Makes the ILP/SMT problems tractable for real GPU cycle counts (~1000 cycles)
  ```

 ## Quick Start

 ```bash
- pip install pulp z3-solver numpy matplotlib
 ```

 ```python
 from twill.kernels import flash_attention_forward_simplified
 from twill.twill_solver import twill_solve
- from twill.visualization import visualize_schedule

- # Build the simplified Flash Attention dependence graph (Figure 1)
 graph = flash_attention_forward_simplified()
-
- # Run the full Twill solver
 result = twill_solve(graph, max_I=5, verbose=True)
-
- # Visualize the result
- print(visualize_schedule(graph, result.joint_result))
 ```

- Output:
- ```
- SOLUTION FOUND in 0.12s
- Initiation Interval I = 2      ← optimal!
- Schedule Length L = 4
- Overlapping copies = 2
- Schedule: S@0, P@2, O@3        ← S extracted into prologue
- Warp Assignment: all on warp 0 (no variable-latency ops)
- ```
-
- ## Pre-built Kernel Descriptions
-
- | Kernel | Section | Architecture | Key Result |
- |--------|---------|--------------|------------|
- | `flash_attention_forward_simplified()` | §3 (Figure 1) | Hopper | I=2, SWP extracts S into prologue |
- | `flash_attention_forward_hopper()` | §6.2.1 | Hopper (H100) | Rediscovers FA3 pipeline + ping-pong |
- | `flash_attention_forward_blackwell()` | §6.2.2 | Blackwell (B200) | Rediscovers FA4 strategy |
- | `simple_gemm_pipeline()` | — | Hopper | Load-compute overlap, TMA on producer warp |
-
- ## Custom Kernels
-
- Define your own kernel dependence graph:
-
 ```python
- from twill.graph import hopper_machine
- from twill.kernels import custom_kernel
- from twill.twill_solver import twill_solve
-
- machine = hopper_machine()
- graph = custom_kernel(
-     machine=machine,
-     instructions=[
-         {"name": "load_A", "cycles": 1, "fu": "TMA", "variable_latency": True, "streaming": True},
-         {"name": "load_B", "cycles": 1, "fu": "TMA", "variable_latency": True, "streaming": True},
-         {"name": "gemm", "cycles": 2, "fu": "TC"},
-         {"name": "relu", "cycles": 1, "fu": "EXP"},
-     ],
-     edges=[
-         {"src": "load_A", "dst": "gemm", "delay": 1},
-         {"src": "load_B", "dst": "gemm", "delay": 1},
-         {"src": "gemm", "dst": "relu", "delay": 2},
-         {"src": "relu", "dst": "relu", "delay": 1, "delta": 1},  # loop-carried dependence
-     ],
- )
-
- result = twill_solve(graph, verbose=True)
 ```

- ## Code Generation
-
- Twill generates three output formats:
-
 ```python
- from twill.codegen import generate_pseudocode, generate_cuda_skeleton, generate_pipelined_code
-
- # Human-readable pseudocode with warp annotations
- print(generate_pseudocode(graph, result.joint_result))
-
- # CUDA C++ skeleton with warp-specialized structure
- print(generate_cuda_skeleton(graph, result.joint_result))
-
- # Structured PipelinedCode object for further processing
- code = generate_pipelined_code(graph, result.joint_result)
- ```


 ## Module Structure

 ```
 twill/
- ├── __init__.py            # Package exports
 ├── graph.py               # DependenceGraph, Instruction, RRT, MachineDescription
- ├── cost_normalization.py  # Section 5.2: ILP-based cycle count normalization
- ├── modulo_scheduler.py    # Phase 1: ILP modulo scheduling (CBC solver)
- ├── smt_joint.py           # Phase 2: SMT joint SWP+WS (Z3 solver)
- ├── twill_solver.py        # Algorithm 1: Main search procedure
 ├── codegen.py             # Code generation (pseudocode, CUDA skeleton)
 ├── visualization.py       # Schedule visualization (text + matplotlib)
 └── kernels.py             # Pre-built kernel descriptions (FMHA, GEMM)
 ```

- ## Constraint Groups (from the paper)
-
- ### Figure 4: Modulo Scheduling Constraints
- - **Uniqueness**: Each operation scheduled exactly once per iteration copy
- - **Consistency**: Modulo structure preserved across copies (offset by I)
- - **Completion**: Operations must finish before end of schedule
- - **Dependence**: Data dependencies respected across iterations
- - **Capacity**: Functional unit capacities not exceeded
-
- ### Figure 5: Memory Allocation Constraints
- - **Memory Capacity**: Working set fits in on-chip memory (SMEM, TMEM, registers)
- - **Liveness**: SSA-based backward dataflow for variable lifetimes
-
- ### Figure 6: Warp Assignment Constraints
- - **WarpUniqueness**: Each instruction assigned to exactly one warp
- - **VariableLatency**: Variable-latency ops go to dedicated producer warp
- - **WarpCapacity**: Per-warp resource budget respected
-
- ## Solvers Used
-
- | Component | Solver | Theory | Paper Reference |
- |-----------|--------|--------|-----------------|
- | Cost Normalization | CBC (PuLP) | ILP | Section 5.2 (paper uses SCIP) |
- | Modulo Scheduling | CBC (PuLP) | ILP | Section 3.1, Stoutchinin et al. |
- | Joint SWP + WS | Z3 | QFLIA (SMT) | Section 4 (paper uses Yices2) |
 

 ## Tests

 ```bash
- python test_twill.py
- ```
-
- ```
- ✓ PASS Cost Normalization
- ✓ PASS Modulo Scheduling Only
- ✓ PASS Simplified FA (Figure 1)
- ✓ PASS Simple GEMM
- ✓ PASS Hopper FMHA Forward
- ✓ PASS Blackwell FMHA Forward
-
- Passed: 6/6
- Total time: ~5s
 ```

- ## Limitations
-
- Following the paper (Section 5.4):
- - Only supports singly-nested loops without control flow
- - Tile size is not automatically determined (external concern)
- - Code generation produces skeletons, not fully compilable CUDA
-   (the paper notes that even their implementation required "hand-compilation" to CUDA C++
-   because Triton made incorrect decisions during code generation)
- ## Citation

 ```bibtex
 @article{soi2024twill,
@@ -188,6 +153,13 @@ Following the paper (Section 5.4):
   journal={arXiv preprint arXiv:2512.18134},
   year={2024}
 }
 ```

 ## Related Work
@@ -198,3 +170,6 @@ Following the paper (Section 5.4):
 - [CUTLASS 3.x](https://github.com/NVIDIA/cutlass) — NVIDIA GEMM templates with WS
 - [Tawa](https://arxiv.org/abs/2510.14719) — Automatic WS compiler (downstream of Twill)
 - [Cypress](https://arxiv.org/abs/2504.07004) — Task-based GPU programming model
+ # Twill + GauS: Optimal GPU Kernel Scheduling

+ Implementation of two complementary papers on GPU kernel scheduling optimization:

+ 1. **[Twill](https://arxiv.org/abs/2512.18134)** — "Optimal Software Pipelining and Warp Specialization for Tensor Core GPUs" (Soi et al., NVIDIA Research, 2024). An exact ILP+SMT solver that finds provably optimal schedules.

+ 2. **[GauS](https://arxiv.org/abs/2602.20427)** — "Differentiable Scheduling Optimization via Gaussian Reparameterization" (Cai et al., 2026). A scalable gradient-based solver using Gaussian distributions and the Augmented Lagrangian Method.

+ ## When to Use Which

+ | Property | Twill (ILP+SMT) | GauS (Differentiable) |
+ |----------|-----------------|-----------------------|
+ | **Optimality** | Provably optimal | Approximate (Pareto-optimal) |
+ | **Speed on small graphs** | Fast (< 1 s for 3-7 nodes) | Slower (~10 s for 3-7 nodes) |
+ | **Scalability** | Exponential in graph size | O(\|V\|) parameters, GPU-accelerable |
+ | **Warp specialization** | Full joint SWP+WS | Schedule only (no WS) |
+ | **Best for** | Kernels with < 50 ops | Large graphs with 100+ ops |
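
Read as a decision rule, the table above reduces to a one-line dispatcher. This is illustrative only: `choose_solver` is not part of either tool's API, and the 50-op threshold simply restates the "Best for" row.

```python
# Illustrative dispatcher; the 50-op cutoff restates the table's "Best for" row.
def choose_solver(num_ops: int) -> str:
    # Twill's exact ILP+SMT search scales exponentially with graph size,
    # so fall back to the gradient-based GauS solver for large graphs.
    return "twill" if num_ops < 50 else "gaus"

print(choose_solver(7), choose_solver(1000))
```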
+ ## Architecture

 ```
+ ┌──────────────────────────────────────────────────────────┐
+ │                    DependenceGraph                       │
+ │          G = (V, E), Machine Description, RRTs           │
+ ├────────────────────┬─────────────────────────────────────┤
+ │ Twill Solver       │ GauS Solver                         │
+ │                    │                                     │
+ │ Phase 1: ILP       │ Gaussian Reparameterization         │
+ │  (CBC/PuLP)        │  X_i ~ N(μ_i, σ_i²)                 │
+ │  → Modulo Schedule │  → P_i^d = Φ((d+0.5-μ)/σ)           │
+ │                    │           - Φ((d-0.5-μ)/σ)          │
+ │ Phase 2: SMT       │                                     │
+ │  (Z3/QFLIA)        │ Augmented Lagrangian Method         │
+ │  → Joint SWP + WS  │  → Adam optimizer on (μ, σ)         │
+ │                    │  → Dependency + Resource +          │
+ │ Cost Norm (§5.2)   │    Modulo + Recurrence losses       │
+ │  → Ratio-preserving│                                     │
+ │    cycle count     │ Legalization Heuristics             │
+ │    reduction       │  → Topological pass (regular)       │
+ │                    │  → Fixed-point iteration (modulo)   │
+ ├────────────────────┴─────────────────────────────────────┤
+ │                     Code Generation                      │
+ │      Pseudocode · CUDA Skeleton · Pipelined Schedule     │
+ └──────────────────────────────────────────────────────────┘
 ```

 ## Quick Start

 ```bash
+ pip install pulp z3-solver numpy matplotlib torch
 ```

+ ### Twill (exact solver)
 ```python
 from twill.kernels import flash_attention_forward_simplified
 from twill.twill_solver import twill_solve

 graph = flash_attention_forward_simplified()
 result = twill_solve(graph, max_I=5, verbose=True)
+ # → I=2, schedule: S@0, P@2, O@3 (proves FA3 optimal)
 ```

+ ### GauS (differentiable solver)
 ```python
+ from twill.kernels import flash_attention_forward_hopper
+ from twill.gaus_solver import gaus_solve_twill_graph

+ graph = flash_attention_forward_hopper()
+ result = gaus_solve_twill_graph(graph, target_II=4, D=20, verbose=True)
+ # → II=4, feasible schedule found via gradient descent
 ```

+ ### GauS on custom large graphs

 ```python
+ from twill.gaus_solver import GauSSolver, generate_random_dag

+ graph = generate_random_dag(num_nodes=1000, edge_density=0.01, num_back_edges=50)
+ solver = GauSSolver(graph, D=500, lr=0.01)
+ result = solver.solve_modulo(II=10, R_cap=50.0, max_iters=2000)
+ ```

+ ## Pre-built Kernel Descriptions

+ | Kernel | Architecture | Twill Result | GauS Result |
+ |--------|--------------|--------------|-------------|
+ | `flash_attention_forward_simplified()` | Hopper | I=2 ✓ (optimal) | II=2 ✓ (feasible) |
+ | `flash_attention_forward_hopper()` | Hopper (H100) | I=4, FA3 pipeline ✓ | II=4 ✓ (feasible) |
+ | `flash_attention_forward_blackwell()` | Blackwell (B200) | I=2, FA4 strategy ✓ | II=2 ✓ (feasible) |
+ | `simple_gemm_pipeline()` | Hopper | I=2, TMA on producer ✓ | II=2 ✓ (feasible) |

 ## Module Structure

 ```
 twill/
+ ├── __init__.py            # Package exports (v0.2.0)
 ├── graph.py               # DependenceGraph, Instruction, RRT, MachineDescription
+ ├── cost_normalization.py  # §5.2: ILP-based cycle count normalization
+ ├── modulo_scheduler.py    # Twill Phase 1: ILP modulo scheduling (CBC)
+ ├── smt_joint.py           # Twill Phase 2: SMT joint SWP+WS (Z3)
+ ├── twill_solver.py        # Twill Algorithm 1: Main search procedure
+ ├── gaus_solver.py         # GauS: Differentiable scheduling (PyTorch)
 ├── codegen.py             # Code generation (pseudocode, CUDA skeleton)
 ├── visualization.py       # Schedule visualization (text + matplotlib)
 └── kernels.py             # Pre-built kernel descriptions (FMHA, GEMM)
 ```

+ ## GauS: Technical Details

+ ### Gaussian Reparameterization (§3.1)
+ Each operator `v_i` is modeled as `X_i ~ N(μ_i, σ_i²)`, with only **2|V| parameters** (vs. D·|V| for categorical approaches). The probability of scheduling at step `d` is:

+ ```
+ P_i^d = Φ((d+0.5-μ_i)/σ_i) - Φ((d-0.5-μ_i)/σ_i)
+ ```
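
A minimal sketch of this binning, in plain Python with `math.erf` standing in for Φ (the repository's `gaus_solver` would use differentiable PyTorch ops instead; `schedule_probs` is an illustrative name, not part of the package):

```python
import math

def norm_cdf(x: float) -> float:
    # Standard normal CDF Φ expressed via the error function
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def schedule_probs(mu: float, sigma: float, D: int) -> list[float]:
    # P_i^d: mass of N(mu, sigma^2) falling in the unit bin [d-0.5, d+0.5)
    return [norm_cdf((d + 0.5 - mu) / sigma) - norm_cdf((d - 0.5 - mu) / sigma)
            for d in range(D)]

p = schedule_probs(mu=3.0, sigma=1.0, D=8)
print(max(range(8), key=lambda d: p[d]))  # most likely step is d = 3
```

As σ shrinks, the mass concentrates on a single bin, so the soft schedule converges to a hard integer assignment.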
 

+ ### Differentiable Loss Functions
+ - **Dependency**: Expected topological violations via CDF products
+ - **Resource**: LogSumExp smooth-max of per-step usage + ReLU violations
+ - **Modulo Resource**: Gaussian probabilities wrapped into an II-slot reservation table
+ - **Recurrence**: Expected back-edge constraint violations
+ - **Memory**: CDF-based estimation of active lifetimes
+ - **Communication**: Expected edge lengths (compactness)
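
Two of these losses can be sketched in a few lines: the resource loss's temperature-weighted LogSumExp with a ReLU hinge, and the modulo folding of per-step probabilities into an II-slot table. Function names and the τ default here are illustrative stand-ins, not the package's API.

```python
import math

def smooth_max(xs: list[float], tau: float = 0.01) -> float:
    # Temperature-weighted LogSumExp: a differentiable upper bound on max(xs)
    m = max(xs)  # subtract the max for numerical stability
    return m + tau * math.log(sum(math.exp((x - m) / tau) for x in xs))

def resource_loss(expected_usage: list[float], capacity: float) -> float:
    # Penalize the smoothed peak of per-step expected usage above capacity
    return max(0.0, smooth_max(expected_usage) - capacity)  # ReLU hinge

def modulo_usage(probs: list[float], II: int) -> list[float]:
    # Fold per-step probabilities into an II-slot modulo reservation table
    slots = [0.0] * II
    for d, p in enumerate(probs):
        slots[d % II] += p
    return slots

print(resource_loss([1.0, 2.0, 3.0], capacity=4.0))   # within budget → 0.0
print(modulo_usage([0.5, 0.25, 0.25, 0.0], II=2))     # → [0.75, 0.25]
```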

+ ### Augmented Lagrangian Method (ALM)
+ ```
+ L_total = L_primary + Σ_i (λ_i · V_i + ρ/2 · ||V_i||²)
+ λ_i ← λ_i + ρ · V_i    (dual update)
+ ```
+ Hyperparameters (from the paper): `lr=0.01, ρ=1e-4, τ=1e-2, κ=1/6`
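
One ALM step for a single constraint group can be sketched with scalars (in the actual solver, `V_i` would be one of the differentiable violation losses above and the penalty term would join the Adam objective; these helper names are illustrative):

```python
def alm_penalty(violation: float, lam: float, rho: float) -> float:
    # Augmented Lagrangian term for one constraint: λ·V + (ρ/2)·V²
    return lam * violation + 0.5 * rho * violation ** 2

def dual_update(lam: float, rho: float, violation: float) -> float:
    # λ_i ← λ_i + ρ·V_i : the multiplier grows while violations persist
    return lam + rho * violation

lam, rho = 0.0, 1e-4
pen = alm_penalty(2.0, lam, rho)  # 0.5 * 1e-4 * 2² = 2e-4
lam = dual_update(lam, rho, 2.0)  # λ rises to 2e-4
print(pen, lam)
```

Repeated dual updates steadily raise the price of any constraint that stays violated, pushing the optimizer toward feasibility without a huge fixed penalty weight.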
 ## Tests

 ```bash
+ python test_twill.py   # Twill: 6/6 pass (~5s)
+ python test_gaus.py    # GauS: 7/7 pass (~30s)
 ```

+ ## Solvers Used

+ | Component | Solver | Theory | Paper |
+ |-----------|--------|--------|-------|
+ | Twill Cost Normalization | CBC (PuLP) | ILP | Soi et al. §5.2 |
+ | Twill Modulo Scheduling | CBC (PuLP) | ILP | Soi et al. §3.1 |
+ | Twill Joint SWP+WS | Z3 | QFLIA (SMT) | Soi et al. §4 |
+ | GauS Scheduling | PyTorch (Adam) | Differentiable | Cai et al. §3 |

+ ## Citations

 ```bibtex
 @article{soi2024twill,
   journal={arXiv preprint arXiv:2512.18134},
   year={2024}
 }
+
+ @article{cai2026gaus,
+   title={GauS: Differentiable Scheduling Optimization via Gaussian Reparameterization},
+   author={Cai, Yaohui and others},
+   journal={arXiv preprint arXiv:2602.20427},
+   year={2026}
+ }
 ```

 ## Related Work

 - [CUTLASS 3.x](https://github.com/NVIDIA/cutlass) — NVIDIA GEMM templates with WS
 - [Tawa](https://arxiv.org/abs/2510.14719) — Automatic WS compiler (downstream of Twill)
 - [Cypress](https://arxiv.org/abs/2504.07004) — Task-based GPU programming model
+ - [Nautilus](https://arxiv.org/abs/2604.14825) — Auto-scheduling tensor compiler (fills Twill's tile-size gap)
+ - [MPK](https://arxiv.org/abs/2512.22219) — Cross-kernel software pipelining
+ - [GS-Schedule](https://github.com/Yu-Maryland/Differentiable_Scheduler_ICML24) — GauS predecessor (categorical approach)