# Twill + GauS: Optimal GPU Kernel Scheduling
Implementation of two complementary papers for GPU kernel scheduling optimization:
1. **[Twill](https://arxiv.org/abs/2512.18134)**: "Optimal Software Pipelining and Warp Specialization for Tensor Core GPUs" (Soi et al., NVIDIA Research, 2024). Exact ILP+SMT solver that finds provably optimal schedules.
2. **[GauS](https://arxiv.org/abs/2602.20427)**: "Differentiable Scheduling Optimization via Gaussian Reparameterization" (Cai et al., 2026). Scalable gradient-based solver using Gaussian distributions and the Augmented Lagrangian Method.
## When to Use Which
| Property | Twill (ILP+SMT) | GauS (Differentiable) |
|----------|-----------------|----------------------|
| **Optimality** | Provably optimal | Approximate (Pareto-optimal) |
| **Speed on small graphs** | Fast (< 1s for 3-7 nodes) | Slower (~10s for 3-7 nodes) |
| **Scalability** | Exponential in graph size | O(\|V\|) parameters, GPU-accelerable |
| **Warp specialization** | Full joint SWP+WS | Schedule only (no WS) |
| **Best for** | Kernels with < 50 ops | Large graphs with 100+ ops |
## Architecture
```
┌─────────────────────────────────────────────────────────┐
│                     DependenceGraph                     │
│          G = (V, E), Machine Description, RRTs          │
├────────────────────┬────────────────────────────────────┤
│  Twill Solver      │  GauS Solver                       │
│                    │                                    │
│  Phase 1: ILP      │  Gaussian Reparameterization       │
│  (CBC/PuLP)        │  X_i ~ N(μ_i, σ_i²)                │
│  → Modulo Schedule │  → P_i^d = Φ((d+0.5-μ)/σ)          │
│                    │           - Φ((d-0.5-μ)/σ)         │
│  Phase 2: SMT      │                                    │
│  (Z3/QFLIA)        │  Augmented Lagrangian Method       │
│  → Joint SWP + WS  │  → Adam optimizer on (μ, σ)        │
│                    │  → Dependency + Resource +         │
│  Cost Norm (§5.2)  │    Modulo + Recurrence losses      │
│  → Ratio-preserving│                                    │
│    cycle count     │  Legalization Heuristics           │
│    reduction       │  → Topological pass (regular)      │
│                    │  → Fixed-point iteration (modulo)  │
├────────────────────┴────────────────────────────────────┤
│                     Code Generation                     │
│     Pseudocode · CUDA Skeleton · Pipelined Schedule     │
└─────────────────────────────────────────────────────────┘
```
## Quick Start
```bash
pip install pulp z3-solver numpy matplotlib torch
```
### Twill (exact solver)
```python
from twill.kernels import flash_attention_forward_simplified
from twill.twill_solver import twill_solve
graph = flash_attention_forward_simplified()
result = twill_solve(graph, max_I=5, verbose=True)
# → I=2, schedule: S@0, P@2, O@3 (proves FA3 optimal)
```
### GauS (differentiable solver)
```python
from twill.kernels import flash_attention_forward_hopper
from twill.gaus_solver import gaus_solve_twill_graph
graph = flash_attention_forward_hopper()
result = gaus_solve_twill_graph(graph, target_II=4, D=20, verbose=True)
# → II=4, feasible schedule found via gradient descent
```
### GauS on custom large graphs
```python
from twill.gaus_solver import GauSSolver, generate_random_dag
graph = generate_random_dag(num_nodes=1000, edge_density=0.01, num_back_edges=50)
solver = GauSSolver(graph, D=500, lr=0.01)
result = solver.solve_modulo(II=10, R_cap=50.0, max_iters=2000)
```
## Pre-built Kernel Descriptions
| Kernel | Architecture | Twill Result | GauS Result |
|--------|-------------|-------------|-------------|
| `flash_attention_forward_simplified()` | Hopper | I=2 ✓ (optimal) | II=2 ✓ (feasible) |
| `flash_attention_forward_hopper()` | Hopper (H100) | I=4, FA3 pipeline ✓ | II=4 ✓ (feasible) |
| `flash_attention_forward_blackwell()` | Blackwell (B200) | I=2, FA4 strategy ✓ | II=2 ✓ (feasible) |
| `simple_gemm_pipeline()` | Hopper | I=2, TMA on producer ✓ | II=2 ✓ (feasible) |
## Module Structure
```
twill/
├── __init__.py              # Package exports (v0.2.0)
├── graph.py                 # DependenceGraph, Instruction, RRT, MachineDescription
├── cost_normalization.py    # §5.2: ILP-based cycle count normalization
├── modulo_scheduler.py      # Twill Phase 1: ILP modulo scheduling (CBC)
├── smt_joint.py             # Twill Phase 2: SMT joint SWP+WS (Z3)
├── twill_solver.py          # Twill Algorithm 1: Main search procedure
├── gaus_solver.py           # GauS: Differentiable scheduling (PyTorch)
├── codegen.py               # Code generation (pseudocode, CUDA skeleton)
├── visualization.py         # Schedule visualization (text + matplotlib)
└── kernels.py               # Pre-built kernel descriptions (FMHA, GEMM)
```
## GauS: Technical Details
### Gaussian Reparameterization (§3.1)
Each operator `v_i` is modeled as `X_i ~ N(μ_i, σ_i²)`, which needs only **2|V| parameters** (vs. D·|V| for categorical approaches). The probability of scheduling `v_i` at step `d` is the Gaussian mass in the unit-width bin centered on `d`:
```
P_i^d = Φ((d+0.5-μ_i)/σ_i) - Φ((d-0.5-μ_i)/σ_i)
```
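The bin probability above can be sketched in a few lines of plain Python. This is a minimal illustration, not the code in `gaus_solver.py` (which uses PyTorch); the helper names `normal_cdf` and `step_probabilities` are assumptions made for this example.

```python
import math

def normal_cdf(x):
    """Standard normal CDF Φ(x), computed from the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def step_probabilities(mu, sigma, D):
    """P^d for d = 0..D-1: Gaussian mass in the bin [d-0.5, d+0.5]."""
    return [
        normal_cdf((d + 0.5 - mu) / sigma) - normal_cdf((d - 0.5 - mu) / sigma)
        for d in range(D)
    ]

# One operator centered on step 3 with a fairly narrow spread.
probs = step_probabilities(mu=3.0, sigma=0.5, D=8)
best_step = max(range(len(probs)), key=probs.__getitem__)
```

As `sigma` shrinks, the mass concentrates on the step nearest `mu`, which is what lets a continuous `(μ, σ)` pair stand in for a discrete schedule slot.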
### Differentiable Loss Functions
- **Dependency**: Expected topological violations via CDF products
- **Resource**: LogSumExp smooth-max of per-step usage + ReLU violations
- **Modulo Resource**: Wrapped Gaussian probabilities into II-slot reservation table
- **Recurrence**: Expected back-edge constraint violations
- **Memory**: CDF-based active lifetime estimation
- **Communication**: Expected edge lengths (compactness)
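As a rough illustration of the resource and modulo-resource terms listed above, the sketch below folds each operator's step probabilities into an II-slot table and penalizes a LogSumExp smooth-max that exceeds capacity. The function names, the `temperature` parameter, and the exact form of the penalty are assumptions for this example; the paper's formulation may differ in detail.

```python
import math

def normal_cdf(x):
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def step_probs(mu, sigma, D):
    """Gaussian bin probabilities for steps 0..D-1."""
    return [normal_cdf((d + 0.5 - mu) / sigma) - normal_cdf((d - 0.5 - mu) / sigma)
            for d in range(D)]

def modulo_usage(mus, sigmas, demands, D, II):
    """Expected resource usage per modulo slot: step d folds onto slot d % II."""
    usage = [0.0] * II
    for mu, sigma, demand in zip(mus, sigmas, demands):
        for d, p in enumerate(step_probs(mu, sigma, D)):
            usage[d % II] += demand * p
    return usage

def smooth_max(values, temperature=0.1):
    """LogSumExp upper bound on max(values); tighter as temperature shrinks."""
    m = max(values)
    return m + temperature * math.log(sum(math.exp((v - m) / temperature) for v in values))

def resource_violation(usage, capacity):
    """ReLU of the smooth peak usage above capacity."""
    return max(0.0, smooth_max(usage) - capacity)

# Two unit-demand operators, one unit of capacity per slot, II = 2.
# Steps 0 and 2 collide in slot 0; steps 0 and 3 land in different slots.
conflict = resource_violation(modulo_usage([0.0, 2.0], [0.3, 0.3], [1.0, 1.0], D=8, II=2), capacity=1.0)
spread = resource_violation(modulo_usage([0.0, 3.0], [0.3, 0.3], [1.0, 1.0], D=8, II=2), capacity=1.0)
```

The colliding placement yields a much larger violation than the spread one, so the gradient with respect to `μ` pushes operators apart across slots.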
### Augmented Lagrangian Method (ALM)
```
L_total = L_primary + Σ_i (λ_i · V_i + ρ/2 · ||V_i||²)
λ_i ← λ_i + ρ · V_i (dual update)
```
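To make the update rule concrete, here is a toy one-variable sketch: minimize `(x - 2)²` subject to `x ≤ 1`, with the violation taken as `V = max(0, x - 1)`. The problem, the plain gradient-descent inner loop (standing in for Adam), and the penalty weight `rho=1.0` are all assumptions chosen for illustration, not the solver's actual settings.

```python
def objective_grad(x):
    """Gradient of the primary loss (x - 2)^2."""
    return 2.0 * (x - 2.0)

def violation(x):
    """Constraint violation V for the constraint x <= 1."""
    return max(0.0, x - 1.0)

def solve_alm(rho=1.0, lr=0.1, outer_iters=30, inner_iters=200):
    x, lam = 0.0, 0.0
    for _ in range(outer_iters):
        # Inner loop: descend on L = f + lam * V + rho/2 * V^2 with lam fixed.
        for _ in range(inner_iters):
            v = violation(x)
            penalty_grad = (lam + rho * v) if v > 0.0 else 0.0
            x -= lr * (objective_grad(x) + penalty_grad)
        # Dual update: raise the multiplier while the constraint is violated.
        lam += rho * violation(x)
    return x, lam

x_opt, lam_opt = solve_alm()
```

On this toy problem the iterate approaches the constrained optimum `x = 1` and the multiplier approaches the KKT value `λ = 2`, even though the penalty weight stays fixed.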
Hyperparameters (from paper): `lr=0.01, Ο=1e-4, Ο=1e-2, ΞΊ=1/6`
## Tests
```bash
python test_twill.py # Twill: 6/6 pass (~5s)
python test_gaus.py # GauS: 7/7 pass (~30s)
```
## Solvers Used
| Component | Solver | Theory | Paper |
|-----------|--------|--------|-------|
| Twill Cost Normalization | CBC (PuLP) | ILP | Soi et al. §5.2 |
| Twill Modulo Scheduling | CBC (PuLP) | ILP | Soi et al. §3.1 |
| Twill Joint SWP+WS | Z3 | QFLIA (SMT) | Soi et al. §4 |
| GauS Scheduling | PyTorch (Adam) | Differentiable | Cai et al. §3 |
## Citations
```bibtex
@article{soi2024twill,
title={Optimal Software Pipelining and Warp Specialization for Tensor Core GPUs},
author={Soi, Rupanshu and others},
journal={arXiv preprint arXiv:2512.18134},
year={2024}
}
@article{cai2026gaus,
title={GauS: Differentiable Scheduling Optimization via Gaussian Reparameterization},
author={Cai, Yaohui and others},
journal={arXiv preprint arXiv:2602.20427},
year={2026}
}
```
## Related Work
- [FlashAttention-3](https://arxiv.org/abs/2407.08608): Hopper FMHA schedule that Twill rediscovers
- [FlashAttention-4](https://arxiv.org/abs/2603.05451): Blackwell FMHA schedule that Twill rediscovers
- [ThunderKittens](https://github.com/HazyResearch/ThunderKittens): Warp-level kernel framework
- [CUTLASS 3.x](https://github.com/NVIDIA/cutlass): NVIDIA GEMM templates with WS
- [Tawa](https://arxiv.org/abs/2510.14719): Automatic WS compiler (downstream of Twill)
- [Cypress](https://arxiv.org/abs/2504.07004): Task-based GPU programming model
- [Nautilus](https://arxiv.org/abs/2604.14825): Auto-scheduling tensor compiler (fills Twill's tile-size gap)
- [MPK](https://arxiv.org/abs/2512.22219): Cross-kernel software pipelining
- [GS-Schedule](https://github.com/Yu-Maryland/Differentiable_Scheduler_ICML24): GauS predecessor (categorical approach)