# Twill + GauS: Optimal GPU Kernel Scheduling
Implementation of two complementary papers for GPU kernel scheduling optimization:
1. **[Twill](https://arxiv.org/abs/2512.18134)**: "Optimal Software Pipelining and Warp Specialization for Tensor Core GPUs" (Soi et al., NVIDIA Research, 2024). Exact ILP+SMT solver that finds provably optimal schedules.
2. **[GauS](https://arxiv.org/abs/2602.20427)**: "Differentiable Scheduling Optimization via Gaussian Reparameterization" (Cai et al., 2026). Scalable gradient-based solver using Gaussian distributions and Augmented Lagrangian Method.
## When to Use Which
| Property | Twill (ILP+SMT) | GauS (Differentiable) |
|----------|-----------------|----------------------|
| **Optimality** | Provably optimal | Approximate (Pareto-optimal) |
| **Speed on small graphs** | Fast (< 1s for 3-7 nodes) | Slower (~10s for 3-7 nodes) |
| **Scalability** | Exponential in graph size | O(\|V\|) parameters, GPU-accelerable |
| **Warp specialization** | Full joint SWP+WS | Schedule only (no WS) |
| **Best for** | Kernels with < 50 ops | Large graphs with 100+ ops |
## Architecture
```
┌─────────────────────────────────────────────────────────┐
│                     DependenceGraph                     │
│          G = (V, E), Machine Description, RRTs          │
├────────────────────┬────────────────────────────────────┤
│ Twill Solver       │ GauS Solver                        │
│                    │                                    │
│ Phase 1: ILP       │ Gaussian Reparameterization        │
│ (CBC/PuLP)         │ X_i ~ N(μ_i, σ_i²)                 │
│ → Modulo Schedule  │ → P_i^d = Φ((d+0.5-μ)/σ)           │
│                    │          - Φ((d-0.5-μ)/σ)           │
│ Phase 2: SMT       │                                    │
│ (Z3/QFLIA)         │ Augmented Lagrangian Method        │
│ → Joint SWP + WS   │ → Adam optimizer on (μ, σ)         │
│                    │ → Dependency + Resource +          │
│ Cost Norm (§5.2)   │   Modulo + Recurrence losses       │
│ → Ratio-preserving │                                    │
│   cycle count      │ Legalization Heuristics            │
│   reduction        │ → Topological pass (regular)       │
│                    │ → Fixed-point iteration (modulo)  │
├────────────────────┴────────────────────────────────────┤
│                     Code Generation                     │
│     Pseudocode · CUDA Skeleton · Pipelined Schedule     │
└─────────────────────────────────────────────────────────┘
```
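The "Topological pass (regular)" legalization step in the diagram can be pictured with a small sketch. It assumes, without confirmation from the repository, that legalization rounds each Gaussian mean `μ_i` to the nearest integer step and then delays each operator until every predecessor's latency is satisfied. `legalize_topological` and the `(src, dst, latency)` edge tuples are hypothetical illustrations, not the `gaus_solver` API; loop-carried back edges are left out because a regular (non-modulo) schedule must be acyclic.

```python
from graphlib import TopologicalSorter
from typing import Dict, List, Tuple

# Hypothetical edge format: (src, dst, latency). Not the repo's DependenceGraph.
Edge = Tuple[str, str, int]


def legalize_topological(mu: Dict[str, float], edges: List[Edge]) -> Dict[str, int]:
    """Turn continuous means into a dependency-respecting integer schedule.

    Sketch only: round each mean, then visit operators in topological order and
    push each one later until t[dst] >= t[src] + latency holds for every
    incoming edge. Raises graphlib.CycleError if the graph has a cycle.
    """
    preds: Dict[str, List[Tuple[str, int]]] = {v: [] for v in mu}
    for src, dst, latency in edges:
        preds[dst].append((src, latency))

    # TopologicalSorter takes a mapping: node -> iterable of predecessors.
    sorter = TopologicalSorter({v: [s for s, _ in ps] for v, ps in preds.items()})

    times: Dict[str, int] = {}
    for v in sorter.static_order():
        earliest = max((times[s] + lat for s, lat in preds[v]), default=0)
        times[v] = max(0, round(mu[v]), earliest)
    return times


if __name__ == "__main__":
    mu = {"load": 0.2, "mma": 0.9, "softmax": 1.4}
    edges = [("load", "mma", 2), ("mma", "softmax", 1)]
    print(legalize_topological(mu, edges))  # {'load': 0, 'mma': 2, 'softmax': 3}
```

A single forward pass suffices here because every predecessor is already fixed when an operator is visited; the modulo case needs the fixed-point iteration instead, since back edges make the constraints cyclic.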
## Quick Start
```bash
pip install pulp z3-solver numpy matplotlib torch
```
### Twill (exact solver)
```python
from twill.kernels import flash_attention_forward_simplified
from twill.twill_solver import twill_solve
graph = flash_attention_forward_simplified()
result = twill_solve(graph, max_I=5, verbose=True)
# → I=2, schedule: S@0, P@2, O@3 (proves FA3 optimal)
```
### GauS (differentiable solver)
```python
from twill.kernels import flash_attention_forward_hopper
from twill.gaus_solver import gaus_solve_twill_graph
graph = flash_attention_forward_hopper()
result = gaus_solve_twill_graph(graph, target_II=4, D=20, verbose=True)
# → II=4, feasible schedule found via gradient descent
```
### GauS on custom large graphs
```python
from twill.gaus_solver import GauSSolver, generate_random_dag
graph = generate_random_dag(num_nodes=1000, edge_density=0.01, num_back_edges=50)
solver = GauSSolver(graph, D=500, lr=0.01)
result = solver.solve_modulo(II=10, R_cap=50.0, max_iters=2000)
```
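To make the shape of such inputs concrete, here is a hypothetical stand-in that builds a random DAG plus a few loop-carried back edges. It is not the repository's `generate_random_dag`: the `(src, dst, latency, distance)` tuple, the unit latency, and the use of `distance=1` for back edges are all assumptions for illustration. Forward edges only go from a lower to a higher node index, which keeps the intra-iteration graph acyclic, while back edges (`src > dst`) model recurrences that only a modulo schedule has to honor.

```python
import random
from typing import List, Tuple

Edge = Tuple[int, int, int, int]  # hypothetical: (src, dst, latency, distance)


def random_dag_with_back_edges(num_nodes: int, edge_density: float,
                               num_back_edges: int, seed: int = 0) -> List[Edge]:
    """Random forward DAG plus `num_back_edges` distinct loop-carried edges."""
    max_pairs = num_nodes * (num_nodes - 1) // 2
    if num_back_edges > max_pairs:
        raise ValueError("more back edges requested than node pairs")
    rng = random.Random(seed)
    edges: List[Edge] = []
    # Forward edges i -> j with i < j: acyclic within one iteration.
    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):
            if rng.random() < edge_density:
                edges.append((i, j, 1, 0))
    # Back edges j -> i (j > i) carry a dependence into the next iteration.
    back = set()
    while len(back) < num_back_edges:
        i, j = sorted(rng.sample(range(num_nodes), 2))
        back.add((j, i))
    edges.extend((src, dst, 1, 1) for src, dst in sorted(back))
    return edges
```

Seeding a private `random.Random` instance keeps generated graphs reproducible without touching global random state.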
## Pre-built Kernel Descriptions
| Kernel | Architecture | Twill Result | GauS Result |
|--------|-------------|-------------|-------------|
| `flash_attention_forward_simplified()` | Hopper | I=2 ✓ (optimal) | II=2 ✓ (feasible) |
| `flash_attention_forward_hopper()` | Hopper (H100) | I=4, FA3 pipeline ✓ | II=4 ✓ (feasible) |
| `flash_attention_forward_blackwell()` | Blackwell (B200) | I=2, FA4 strategy ✓ | II=2 ✓ (feasible) |
| `simple_gemm_pipeline()` | Hopper | I=2, TMA on producer ✓ | II=2 ✓ (feasible) |
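The "feasible" entries in the GauS column refer to modulo schedules that respect both dependencies and per-slot resource limits. A minimal checker for those two standard modulo-scheduling conditions is sketched below. The `(src, dst, latency, distance)` edge format, the unit names, the toy graph, and `check_modulo_schedule` itself are hypothetical and independent of the repository's `DependenceGraph`/RRT classes. An edge with iteration distance `k` requires `t[dst] + k·II ≥ t[src] + latency`, and no functional unit may be oversubscribed in any slot `t mod II`.

```python
from collections import Counter
from typing import Dict, List, Tuple

Edge = Tuple[str, str, int, int]  # hypothetical: (src, dst, latency, distance)


def check_modulo_schedule(times: Dict[str, int], units: Dict[str, str],
                          edges: List[Edge], II: int,
                          capacity: Dict[str, int]) -> List[str]:
    """Return the violated constraints; an empty list means feasible."""
    violations: List[str] = []
    # Dependence: iteration k+distance of dst must start >= latency cycles
    # after iteration k of src, i.e. t[dst] + distance*II >= t[src] + latency.
    for src, dst, latency, distance in edges:
        if times[dst] + distance * II < times[src] + latency:
            violations.append(f"dep {src}->{dst}")
    # Resources: ops from overlapping iterations share the slot t mod II.
    usage = Counter((units[op], t % II) for op, t in times.items())
    for (unit, slot), count in sorted(usage.items()):
        if count > capacity.get(unit, 1):
            violations.append(f"resource {unit}@slot{slot}")
    return violations


if __name__ == "__main__":
    times = {"load": 0, "mma": 2, "softmax": 3}
    units = {"load": "tma", "mma": "tensor_core", "softmax": "alu"}
    edges = [("load", "mma", 2, 0), ("mma", "softmax", 1, 0),
             ("softmax", "mma", 1, 1)]  # loop-carried recurrence
    cap = {"tma": 1, "tensor_core": 1, "alu": 1}
    print(check_modulo_schedule(times, units, edges, II=2, capacity=cap))  # []
    print(check_modulo_schedule(times, units, edges, II=1, capacity=cap))  # ['dep softmax->mma']
```

In this toy graph the recurrence `mma → softmax → mma` has total latency 2 over one iteration, so no schedule can reach II below 2; the checker reports exactly that edge at II=1.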
## Module Structure
```
twill/
├── __init__.py           # Package exports (v0.2.0)
├── graph.py              # DependenceGraph, Instruction, RRT, MachineDescription
├── cost_normalization.py # §5.2: ILP-based cycle count normalization
├── modulo_scheduler.py   # Twill Phase 1: ILP modulo scheduling (CBC)
├── smt_joint.py          # Twill Phase 2: SMT joint SWP+WS (Z3)
├── twill_solver.py       # Twill Algorithm 1: Main search procedure
├── gaus_solver.py        # GauS: Differentiable scheduling (PyTorch)
├── codegen.py            # Code generation (pseudocode, CUDA skeleton)
├── visualization.py      # Schedule visualization (text + matplotlib)
└── kernels.py            # Pre-built kernel descriptions (FMHA, GEMM)
```
## GauS: Technical Details
### Gaussian Reparameterization (Β§3.1)
Each operator `v_i` is modeled as `X_i ~ N(μ_i, σ_i²)`, so the whole schedule needs only **2|V| parameters** (vs. D·|V| for categorical approaches, where D is the number of candidate time steps). The probability that `v_i` is scheduled at step `d` is:
```
P_i^d = Φ((d+0.5-μ_i)/σ_i) - Φ((d-0.5-μ_i)/σ_i)
```
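The formula can be evaluated directly with the standard normal CDF Φ. The sketch below uses plain Python for readability (a real solver would use a tensor library so that gradients flow back to `μ_i` and `σ_i`), and its function names are illustrative rather than the `gaus_solver` API. Because adjacent terms telescope, the probabilities over `d = 0..D-1` sum to `Φ((D-0.5-μ_i)/σ_i) - Φ((-0.5-μ_i)/σ_i)`, which is close to 1 whenever `μ_i` sits well inside the horizon.

```python
import math
from typing import List


def normal_cdf(z: float) -> float:
    """Standard normal CDF Φ(z), written with math.erf."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def step_probabilities(mu: float, sigma: float, D: int) -> List[float]:
    """P_i^d for d = 0..D-1: Gaussian mass falling in [d - 0.5, d + 0.5)."""
    return [normal_cdf((d + 0.5 - mu) / sigma) - normal_cdf((d - 0.5 - mu) / sigma)
            for d in range(D)]


if __name__ == "__main__":
    probs = step_probabilities(mu=3.2, sigma=0.8, D=10)
    print(max(range(10), key=probs.__getitem__))  # 3: the step nearest to mu
    print(round(sum(probs), 4))                   # 1.0: almost no mass outside [0, D)

    # Parameter count for |V| = 1000 operators and D = 500 steps.
    V, D = 1000, 500
    print(2 * V, D * V)  # 2000 (Gaussian) vs. 500000 (categorical)
```

Shrinking `σ_i` concentrates the mass on the step nearest `μ_i`, which is how an optimized distribution approaches a discrete schedule.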
### Differentiable Loss Functions
- **Dependency**: Expected topological violations via CDF products
- **Resource**: LogSumExp smooth-max of per-step usage + ReLU violations
- **Modulo Resource**: Wrapped Gaussian probabilities into II-slot reservation table
- **Recurrence**: Expected back-edge constraint violations
- **Memory**: CDF-based active lifetime estimation
- **Communication**: Expected edge lengths (compactness)
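Three of these terms can be sketched on top of the discretized probabilities `P_i^d`. These are illustrative formulations under stated assumptions, not the exact losses in `gaus_solver.py`: operator start times are treated as independent; a dependency (or recurrence) violation is the probability `P(X_dst + distance·II < X_src + latency)`, computed as a product of the source PMF with the destination CDF; and the modulo resource term wraps each step `d` into slot `d mod II`, takes a LogSumExp smooth-max over the slots, and applies a ReLU against the capacity. The default temperature is an arbitrary value for the example.

```python
import math
from typing import List, Sequence


def step_probabilities(mu: float, sigma: float, D: int) -> List[float]:
    """P_i^d for d = 0..D-1 (discretized Gaussian, as in the formula above)."""
    def cdf(z: float) -> float:
        return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))
    return [cdf((d + 0.5 - mu) / sigma) - cdf((d - 0.5 - mu) / sigma) for d in range(D)]


def expected_violation(p_src: Sequence[float], p_dst: Sequence[float],
                       latency: int, distance: int = 0, II: int = 0) -> float:
    """P(X_dst + distance*II < X_src + latency), assuming independent X_src, X_dst.

    distance = 0 is a plain dependency edge; distance > 0 is a recurrence
    (back) edge whose slack grows with the initiation interval II.
    """
    cdf_lt = [0.0]                      # cdf_lt[k] = P(X_dst < k), k = 0..D
    for p in p_dst:
        cdf_lt.append(cdf_lt[-1] + p)
    total = 0.0
    for d, p in enumerate(p_src):
        k = d + latency - distance * II  # violated iff X_dst < k
        k = min(max(k, 0), len(p_dst))
        total += p * cdf_lt[k]
    return total


def smooth_max(values: Sequence[float], temperature: float) -> float:
    """LogSumExp smooth-max; overestimates max() by at most temperature*log(n)."""
    m = max(values)
    return m + temperature * math.log(sum(math.exp((v - m) / temperature) for v in values))


def modulo_resource_loss(probs_per_op: Sequence[Sequence[float]], II: int,
                         capacity: float, temperature: float = 0.1) -> float:
    """ReLU(smooth-max over II slots of expected usage - capacity), one unit."""
    usage = [0.0] * II
    for probs in probs_per_op:
        for d, p in enumerate(probs):
            usage[d % II] += p           # wrap step d into slot d mod II
    return max(0.0, smooth_max(usage, temperature) - capacity)


if __name__ == "__main__":
    D, s = 8, 0.2
    src = step_probabilities(0.0, s, D)
    print(expected_violation(src, step_probabilities(3.0, s, D), latency=2))  # ~0
    print(expected_violation(src, step_probabilities(1.0, s, D), latency=2))  # ~1
    same = [step_probabilities(0.0, s, D)] * 2       # both ops want step 0
    spread = [step_probabilities(0.0, s, D), step_probabilities(1.0, s, D)]
    print(modulo_resource_loss(same, II=2, capacity=1.0))    # ~1 (conflict)
    print(modulo_resource_loss(spread, II=2, capacity=1.0))  # small
```

Every quantity here is a smooth function of `μ` and `σ`, which is what lets gradient descent reduce these violations; the small residual loss in the `spread` case comes from the smooth-max overestimate.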
### Augmented Lagrangian Method (ALM)
```
L_total = L_primary + Σ_i (λ_i · V_i + ρ/2 · ||V_i||²)
λ_i ← λ_i + ρ · V_i    (dual update)
```
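The update rule can be seen on a one-variable toy problem: minimize `(x - 3)²` subject to `x ≤ 1`, with violation `V(x) = max(0, x - 1)`. The inner loop runs plain gradient descent on `L_total` (standing in for Adam), and the outer loop applies the dual update. The toy uses `ρ = 1` rather than the paper's `ρ = 1e-4` so that it converges within a few dozen outer steps; all names and values are illustrative. At the constrained optimum `x = 1`, stationarity requires `2(x - 3) + λ = 0`, so λ should approach 4.

```python
from typing import Tuple


def alm_toy(rho: float = 1.0, lr: float = 0.1,
            outer_iters: int = 50, inner_iters: int = 200) -> Tuple[float, float]:
    """Minimize (x-3)^2 s.t. V(x) = max(0, x-1) = 0 with an augmented Lagrangian."""
    x, lam = 0.0, 0.0
    for _ in range(outer_iters):
        # Primal phase: gradient descent on
        #   L_total(x) = (x-3)^2 + lam*V(x) + rho/2 * V(x)^2
        for _ in range(inner_iters):
            v = max(0.0, x - 1.0)
            dv = 1.0 if x > 1.0 else 0.0          # (sub)gradient of V
            grad = 2.0 * (x - 3.0) + (lam + rho * v) * dv
            x -= lr * grad
        # Dual phase: lam <- lam + rho * V(x)
        lam += rho * max(0.0, x - 1.0)
    return x, lam


if __name__ == "__main__":
    x, lam = alm_toy()
    print(f"x={x:.4f}, lambda={lam:.4f}")  # x=1.0000, lambda=4.0000
```

Each dual step raises λ in proportion to the remaining violation, so the penalty tightens only as much as needed; with a tiny ρ the same loop still converges, just over many more outer steps.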
Hyperparameters (from the GauS paper): `lr=0.01, ρ=1e-4, τ=1e-2, κ=1/6`
## Tests
```bash
python test_twill.py # Twill: 6/6 pass (~5s)
python test_gaus.py # GauS: 7/7 pass (~30s)
```
## Solvers Used
| Component | Solver | Theory | Paper |
|-----------|--------|--------|-------|
| Twill Cost Normalization | CBC (PuLP) | ILP | Soi et al. §5.2 |
| Twill Modulo Scheduling | CBC (PuLP) | ILP | Soi et al. §3.1 |
| Twill Joint SWP+WS | Z3 | QFLIA (SMT) | Soi et al. §4 |
| GauS Scheduling | PyTorch (Adam) | Differentiable | Cai et al. §3 |
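The Twill cost normalization step reduces raw cycle counts while preserving their ratios, and the repository solves it as an ILP with CBC. The sketch below swaps in a plain brute-force search purely to illustrate the goal. It assumes, hypothetically, that "ratio-preserving" means every reduced cost stays within a relative tolerance `tol` of its exact scaled value, and it returns the smallest integer costs (each at least 1) that meet that tolerance. `normalize_costs` and its parameters are illustrative names, not the `cost_normalization.py` API.

```python
from typing import List, Optional


def normalize_costs(costs: List[int], tol: float = 0.05,
                    max_scale: int = 64) -> Optional[List[int]]:
    """Smallest integer costs whose ratios match `costs` within relative `tol`.

    Brute-force stand-in for an ILP: try target maxima M = 1, 2, ..., max_scale,
    scale every cost by M / max(costs), round (keeping each cost >= 1), and
    accept the first M for which every cost is within `tol` of its exact
    scaled value. Returns None if no M up to max_scale works.
    """
    biggest = max(costs)
    for M in range(1, max_scale + 1):
        scale = M / biggest
        scaled = [max(1, round(c * scale)) for c in costs]
        if all(abs(s - c * scale) <= tol * c * scale for s, c in zip(scaled, costs)):
            return scaled
    return None


if __name__ == "__main__":
    print(normalize_costs([640, 320, 64]))  # [10, 5, 1]
    print(normalize_costs([700, 300]))      # [7, 3]
```

An ILP formulation can express richer objectives (for example, trading total size against worst-case ratio error), which a linear scan over a single scale factor cannot.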
## Citations
```bibtex
@article{soi2024twill,
title={Optimal Software Pipelining and Warp Specialization for Tensor Core GPUs},
author={Soi, Rupanshu and others},
journal={arXiv preprint arXiv:2512.18134},
year={2024}
}
@article{cai2026gaus,
title={GauS: Differentiable Scheduling Optimization via Gaussian Reparameterization},
author={Cai, Yaohui and others},
journal={arXiv preprint arXiv:2602.20427},
year={2026}
}
```
## Related Work
- [FlashAttention-3](https://arxiv.org/abs/2407.08608): Hopper FMHA schedule that Twill rediscovers
- [FlashAttention-4](https://arxiv.org/abs/2603.05451): Blackwell FMHA schedule that Twill rediscovers
- [ThunderKittens](https://github.com/HazyResearch/ThunderKittens): Warp-level kernel framework
- [CUTLASS 3.x](https://github.com/NVIDIA/cutlass): NVIDIA GEMM templates with WS
- [Tawa](https://arxiv.org/abs/2510.14719): Automatic WS compiler (downstream of Twill)
- [Cypress](https://arxiv.org/abs/2504.07004): Task-based GPU programming model
- [Nautilus](https://arxiv.org/abs/2604.14825): Auto-scheduling tensor compiler (fills Twill's tile-size gap)
- [MPK](https://arxiv.org/abs/2512.22219): Cross-kernel software pipelining
- [GS-Schedule](https://github.com/Yu-Maryland/Differentiable_Scheduler_ICML24): GauS predecessor (categorical approach)