Upload README.md
Browse files
README.md
CHANGED
|
@@ -1,185 +1,150 @@
|
|
| 1 |
-
# Twill: Optimal
|
| 2 |
|
| 3 |
-
Implementation of
|
| 4 |
|
| 5 |
-
|
| 6 |
|
| 7 |
-
|
| 8 |
|
| 9 |
-
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
|
| 14 |
|
| 15 |
```
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
```
|
| 27 |
|
| 28 |
## Quick Start
|
| 29 |
|
| 30 |
```bash
|
| 31 |
-
pip install pulp z3-solver numpy matplotlib
|
| 32 |
```
|
| 33 |
|
|
|
|
| 34 |
```python
|
| 35 |
from twill.kernels import flash_attention_forward_simplified
|
| 36 |
from twill.twill_solver import twill_solve
|
| 37 |
-
from twill.visualization import visualize_schedule
|
| 38 |
|
| 39 |
-
# Build the simplified Flash Attention dependence graph (Figure 1)
|
| 40 |
graph = flash_attention_forward_simplified()
|
| 41 |
-
|
| 42 |
-
# Run the full Twill solver
|
| 43 |
result = twill_solve(graph, max_I=5, verbose=True)
|
| 44 |
-
|
| 45 |
-
# Visualize the result
|
| 46 |
-
print(visualize_schedule(graph, result.joint_result))
|
| 47 |
```
|
| 48 |
|
| 49 |
-
|
| 50 |
-
```
|
| 51 |
-
SOLUTION FOUND in 0.12s
|
| 52 |
-
Initiation Interval I = 2 β optimal!
|
| 53 |
-
Schedule Length L = 4
|
| 54 |
-
Overlapping copies = 2
|
| 55 |
-
Schedule: S@0, P@2, O@3 β S extracted into prologue
|
| 56 |
-
Warp Assignment: all on warp 0 (no variable-latency ops)
|
| 57 |
-
```
|
| 58 |
-
|
| 59 |
-
## Pre-built Kernel Descriptions
|
| 60 |
-
|
| 61 |
-
| Kernel | Section | Architecture | Key Result |
|
| 62 |
-
|--------|---------|-------------|------------|
|
| 63 |
-
| `flash_attention_forward_simplified()` | Β§3 (Figure 1) | Hopper | I=2, SWP extracts S into prologue |
|
| 64 |
-
| `flash_attention_forward_hopper()` | Β§6.2.1 | Hopper (H100) | Rediscovers FA3 pipeline + ping-pong |
|
| 65 |
-
| `flash_attention_forward_blackwell()` | Β§6.2.2 | Blackwell (B200) | Rediscovers FA4 strategy |
|
| 66 |
-
| `simple_gemm_pipeline()` | β | Hopper | Load-compute overlap, TMA on producer warp |
|
| 67 |
-
|
| 68 |
-
## Custom Kernels
|
| 69 |
-
|
| 70 |
-
Define your own kernel dependence graph:
|
| 71 |
-
|
| 72 |
```python
|
| 73 |
-
from twill.
|
| 74 |
-
from twill.
|
| 75 |
-
from twill.twill_solver import twill_solve
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
instructions=[
|
| 81 |
-
{"name": "load_A", "cycles": 1, "fu": "TMA", "variable_latency": True, "streaming": True},
|
| 82 |
-
{"name": "load_B", "cycles": 1, "fu": "TMA", "variable_latency": True, "streaming": True},
|
| 83 |
-
{"name": "gemm", "cycles": 2, "fu": "TC"},
|
| 84 |
-
{"name": "relu", "cycles": 1, "fu": "EXP"},
|
| 85 |
-
],
|
| 86 |
-
edges=[
|
| 87 |
-
{"src": "load_A", "dst": "gemm", "delay": 1},
|
| 88 |
-
{"src": "load_B", "dst": "gemm", "delay": 1},
|
| 89 |
-
{"src": "gemm", "dst": "relu", "delay": 2},
|
| 90 |
-
{"src": "relu", "dst": "relu", "delay": 1, "delta": 1}, # loop-carried
|
| 91 |
-
],
|
| 92 |
-
)
|
| 93 |
-
|
| 94 |
-
result = twill_solve(graph, verbose=True)
|
| 95 |
```
|
| 96 |
|
| 97 |
-
##
|
| 98 |
-
|
| 99 |
-
Twill generates three output formats:
|
| 100 |
-
|
| 101 |
```python
|
| 102 |
-
from twill.
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
| 106 |
|
| 107 |
-
#
|
| 108 |
-
print(generate_cuda_skeleton(graph, result.joint_result))
|
| 109 |
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
``
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
## Module Structure
|
| 115 |
|
| 116 |
```
|
| 117 |
twill/
|
| 118 |
-
βββ __init__.py # Package exports
|
| 119 |
βββ graph.py # DependenceGraph, Instruction, RRT, MachineDescription
|
| 120 |
-
βββ cost_normalization.py #
|
| 121 |
-
βββ modulo_scheduler.py # Phase 1: ILP modulo scheduling (CBC
|
| 122 |
-
βββ smt_joint.py # Phase 2: SMT joint SWP+WS (Z3
|
| 123 |
-
βββ twill_solver.py # Algorithm 1: Main search procedure
|
|
|
|
| 124 |
βββ codegen.py # Code generation (pseudocode, CUDA skeleton)
|
| 125 |
βββ visualization.py # Schedule visualization (text + matplotlib)
|
| 126 |
βββ kernels.py # Pre-built kernel descriptions (FMHA, GEMM)
|
| 127 |
```
|
| 128 |
|
| 129 |
-
##
|
| 130 |
-
|
| 131 |
-
### Figure 4: Modulo Scheduling Constraints
|
| 132 |
-
- **Uniqueness**: Each operation scheduled exactly once per iteration copy
|
| 133 |
-
- **Consistency**: Modulo structure preserved across copies (offset by I)
|
| 134 |
-
- **Completion**: Operations must finish before end of schedule
|
| 135 |
-
- **Dependence**: Data dependencies respected across iterations
|
| 136 |
-
- **Capacity**: Functional unit capacities not exceeded
|
| 137 |
|
| 138 |
-
###
|
| 139 |
-
|
| 140 |
-
- **Liveness**: SSA-based backward dataflow for variable lifetimes
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
- **WarpCapacity**: Per-warp resource budget respected
|
| 146 |
|
| 147 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
|
|
|
| 154 |
|
| 155 |
## Tests
|
| 156 |
|
| 157 |
```bash
|
| 158 |
-
python test_twill.py
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
```
|
| 162 |
-
β PASS Cost Normalization
|
| 163 |
-
β PASS Modulo Scheduling Only
|
| 164 |
-
β PASS Simplified FA (Figure 1)
|
| 165 |
-
β PASS Simple GEMM
|
| 166 |
-
β PASS Hopper FMHA Forward
|
| 167 |
-
β PASS Blackwell FMHA Forward
|
| 168 |
-
|
| 169 |
-
Passed: 6/6
|
| 170 |
-
Total time: ~5s
|
| 171 |
```
|
| 172 |
|
| 173 |
-
##
|
| 174 |
|
| 175 |
-
|
| 176 |
-
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
|
| 182 |
-
##
|
| 183 |
|
| 184 |
```bibtex
|
| 185 |
@article{soi2024twill,
|
|
@@ -188,6 +153,13 @@ Following the paper (Section 5.4):
|
|
| 188 |
journal={arXiv preprint arXiv:2512.18134},
|
| 189 |
year={2024}
|
| 190 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
```
|
| 192 |
|
| 193 |
## Related Work
|
|
@@ -198,3 +170,6 @@ Following the paper (Section 5.4):
|
|
| 198 |
- [CUTLASS 3.x](https://github.com/NVIDIA/cutlass) β NVIDIA GEMM templates with WS
|
| 199 |
- [Tawa](https://arxiv.org/abs/2510.14719) β Automatic WS compiler (downstream of Twill)
|
| 200 |
- [Cypress](https://arxiv.org/abs/2504.07004) β Task-based GPU programming model
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Twill + GauS: Optimal GPU Kernel Scheduling
|
| 2 |
|
| 3 |
+
Implementation of two complementary papers for GPU kernel scheduling optimization:
|
| 4 |
|
| 5 |
+
1. **[Twill](https://arxiv.org/abs/2512.18134)** β "Optimal Software Pipelining and Warp Specialization for Tensor Core GPUs" (Soi et al., NVIDIA Research, 2024). Exact ILP+SMT solver that finds provably optimal schedules.
|
| 6 |
|
| 7 |
+
2. **[GauS](https://arxiv.org/abs/2602.20427)** β "Differentiable Scheduling Optimization via Gaussian Reparameterization" (Cai et al., 2026). Scalable gradient-based solver using Gaussian distributions and Augmented Lagrangian Method.
|
| 8 |
|
| 9 |
+
## When to Use Which
|
| 10 |
|
| 11 |
+
| Property | Twill (ILP+SMT) | GauS (Differentiable) |
|
| 12 |
+
|----------|-----------------|----------------------|
|
| 13 |
+
| **Optimality** | Provably optimal | Approximate (Pareto-optimal) |
|
| 14 |
+
| **Speed on small graphs** | Fast (< 1s for 3-7 nodes) | Slower (~10s for 3-7 nodes) |
|
| 15 |
+
| **Scalability** | Exponential in graph size | O(|V|) parameters, GPU-accelerable |
|
| 16 |
+
| **Warp specialization** | Full joint SWP+WS | Schedule only (no WS) |
|
| 17 |
+
| **Best for** | Kernels with < 50 ops | Large graphs with 100+ ops |
|
| 18 |
|
| 19 |
+
## Architecture
|
| 20 |
|
| 21 |
```
|
| 22 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 23 |
+
β DependenceGraph β
|
| 24 |
+
β G = (V, E), Machine Description, RRTs β
|
| 25 |
+
ββββββββββββββββββββββ¬βββββββββββββββββββββββββββββββββββββ€
|
| 26 |
+
β Twill Solver β GauS Solver β
|
| 27 |
+
β β β
|
| 28 |
+
β Phase 1: ILP β Gaussian Reparameterization β
|
| 29 |
+
β (CBC/PuLP) β X_i ~ N(ΞΌ_i, Ο_iΒ²) β
|
| 30 |
+
β β Modulo Schedule β β P_i^d = Ξ¦((d+0.5-ΞΌ)/Ο) β
|
| 31 |
+
β β - Ξ¦((d-0.5-ΞΌ)/Ο) β
|
| 32 |
+
β Phase 2: SMT β β
|
| 33 |
+
β (Z3/QFLIA) β Augmented Lagrangian Method β
|
| 34 |
+
β β Joint SWP + WS β β Adam optimizer on (ΞΌ, Ο) β
|
| 35 |
+
β β β Dependency + Resource + β
|
| 36 |
+
β Cost Norm (Β§5.2) β Modulo + Recurrence losses β
|
| 37 |
+
β β Ratio-preservingβ β
|
| 38 |
+
β cycle count β Legalization Heuristics β
|
| 39 |
+
β reduction β β Topological pass (regular) β
|
| 40 |
+
β β β Fixed-point iteration (modulo) β
|
| 41 |
+
ββββββββββββββββββββββ΄βββββββββββββββββββββββββββββββββββββ€
|
| 42 |
+
β Code Generation β
|
| 43 |
+
β Pseudocode Β· CUDA Skeleton Β· Pipelined Schedule β
|
| 44 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 45 |
```
|
| 46 |
|
| 47 |
## Quick Start
|
| 48 |
|
| 49 |
```bash
|
| 50 |
+
pip install pulp z3-solver numpy matplotlib torch
|
| 51 |
```
|
| 52 |
|
| 53 |
+
### Twill (exact solver)
|
| 54 |
```python
|
| 55 |
from twill.kernels import flash_attention_forward_simplified
|
| 56 |
from twill.twill_solver import twill_solve
|
|
|
|
| 57 |
|
|
|
|
| 58 |
graph = flash_attention_forward_simplified()
|
|
|
|
|
|
|
| 59 |
result = twill_solve(graph, max_I=5, verbose=True)
|
| 60 |
+
# β I=2, schedule: S@0, P@2, O@3 (proves FA3 optimal)
|
|
|
|
|
|
|
| 61 |
```
|
| 62 |
|
| 63 |
+
### GauS (differentiable solver)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
```python
|
| 65 |
+
from twill.kernels import flash_attention_forward_hopper
|
| 66 |
+
from twill.gaus_solver import gaus_solve_twill_graph
|
|
|
|
| 67 |
|
| 68 |
+
graph = flash_attention_forward_hopper()
|
| 69 |
+
result = gaus_solve_twill_graph(graph, target_II=4, D=20, verbose=True)
|
| 70 |
+
# β II=4, feasible schedule found via gradient descent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
```
|
| 72 |
|
| 73 |
+
### GauS on custom large graphs
|
|
|
|
|
|
|
|
|
|
| 74 |
```python
|
| 75 |
+
from twill.gaus_solver import GauSSolver, generate_random_dag
|
| 76 |
|
| 77 |
+
graph = generate_random_dag(num_nodes=1000, edge_density=0.01, num_back_edges=50)
|
| 78 |
+
solver = GauSSolver(graph, D=500, lr=0.01)
|
| 79 |
+
result = solver.solve_modulo(II=10, R_cap=50.0, max_iters=2000)
|
| 80 |
+
```
|
| 81 |
|
| 82 |
+
## Pre-built Kernel Descriptions
|
|
|
|
| 83 |
|
| 84 |
+
| Kernel | Architecture | Twill Result | GauS Result |
|
| 85 |
+
|--------|-------------|-------------|-------------|
|
| 86 |
+
| `flash_attention_forward_simplified()` | Hopper | I=2 β (optimal) | II=2 β (feasible) |
|
| 87 |
+
| `flash_attention_forward_hopper()` | Hopper (H100) | I=4, FA3 pipeline β | II=4 β (feasible) |
|
| 88 |
+
| `flash_attention_forward_blackwell()` | Blackwell (B200) | I=2, FA4 strategy β | II=2 β (feasible) |
|
| 89 |
+
| `simple_gemm_pipeline()` | Hopper | I=2, TMA on producer β | II=2 β (feasible) |
|
| 90 |
|
| 91 |
## Module Structure
|
| 92 |
|
| 93 |
```
|
| 94 |
twill/
|
| 95 |
+
βββ __init__.py # Package exports (v0.2.0)
|
| 96 |
βββ graph.py # DependenceGraph, Instruction, RRT, MachineDescription
|
| 97 |
+
βββ cost_normalization.py # Β§5.2: ILP-based cycle count normalization
|
| 98 |
+
βββ modulo_scheduler.py # Twill Phase 1: ILP modulo scheduling (CBC)
|
| 99 |
+
βββ smt_joint.py # Twill Phase 2: SMT joint SWP+WS (Z3)
|
| 100 |
+
βββ twill_solver.py # Twill Algorithm 1: Main search procedure
|
| 101 |
+
βββ gaus_solver.py # GauS: Differentiable scheduling (PyTorch)
|
| 102 |
βββ codegen.py # Code generation (pseudocode, CUDA skeleton)
|
| 103 |
βββ visualization.py # Schedule visualization (text + matplotlib)
|
| 104 |
βββ kernels.py # Pre-built kernel descriptions (FMHA, GEMM)
|
| 105 |
```
|
| 106 |
|
| 107 |
+
## GauS: Technical Details
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
+
### Gaussian Reparameterization (Β§3.1)
|
| 110 |
+
Each operator `v_i` is modeled as `X_i ~ N(ΞΌ_i, Ο_iΒ²)` with only **2|V| parameters** (vs. DΒ·|V| for categorical approaches). The probability of scheduling at step `d`:
|
|
|
|
| 111 |
|
| 112 |
+
```
|
| 113 |
+
P_i^d = Ξ¦((d+0.5-ΞΌ_i)/Ο_i) - Ξ¦((d-0.5-ΞΌ_i)/Ο_i)
|
| 114 |
+
```
|
|
|
|
| 115 |
|
| 116 |
+
### Differentiable Loss Functions
|
| 117 |
+
- **Dependency**: Expected topological violations via CDF products
|
| 118 |
+
- **Resource**: LogSumExp smooth-max of per-step usage + ReLU violations
|
| 119 |
+
- **Modulo Resource**: Wrapped Gaussian probabilities into II-slot reservation table
|
| 120 |
+
- **Recurrence**: Expected back-edge constraint violations
|
| 121 |
+
- **Memory**: CDF-based active lifetime estimation
|
| 122 |
+
- **Communication**: Expected edge lengths (compactness)
|
| 123 |
|
| 124 |
+
### Augmented Lagrangian Method (ALM)
|
| 125 |
+
```
|
| 126 |
+
L_total = L_primary + Ξ£_i (Ξ»_i Β· V_i + Ο/2 Β· ||V_i||Β²)
|
| 127 |
+
Ξ»_i β Ξ»_i + Ο Β· V_i (dual update)
|
| 128 |
+
```
|
| 129 |
+
Hyperparameters (from paper): `lr=0.01, Ο=1e-4, Ο=1e-2, ΞΊ=1/6`
|
| 130 |
|
| 131 |
## Tests
|
| 132 |
|
| 133 |
```bash
|
| 134 |
+
python test_twill.py # Twill: 6/6 pass (~5s)
|
| 135 |
+
python test_gaus.py # GauS: 7/7 pass (~30s)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
```
|
| 137 |
|
| 138 |
+
## Solvers Used
|
| 139 |
|
| 140 |
+
| Component | Solver | Theory | Paper |
|
| 141 |
+
|-----------|--------|--------|-------|
|
| 142 |
+
| Twill Cost Normalization | CBC (PuLP) | ILP | Soi et al. Β§5.2 |
|
| 143 |
+
| Twill Modulo Scheduling | CBC (PuLP) | ILP | Soi et al. Β§3.1 |
|
| 144 |
+
| Twill Joint SWP+WS | Z3 | QFLIA (SMT) | Soi et al. Β§4 |
|
| 145 |
+
| GauS Scheduling | PyTorch (Adam) | Differentiable | Cai et al. Β§3 |
|
| 146 |
|
| 147 |
+
## Citations
|
| 148 |
|
| 149 |
```bibtex
|
| 150 |
@article{soi2024twill,
|
|
|
|
| 153 |
journal={arXiv preprint arXiv:2512.18134},
|
| 154 |
year={2024}
|
| 155 |
}
|
| 156 |
+
|
| 157 |
+
@article{cai2026gaus,
|
| 158 |
+
title={GauS: Differentiable Scheduling Optimization via Gaussian Reparameterization},
|
| 159 |
+
author={Cai, Yaohui and others},
|
| 160 |
+
journal={arXiv preprint arXiv:2602.20427},
|
| 161 |
+
year={2026}
|
| 162 |
+
}
|
| 163 |
```
|
| 164 |
|
| 165 |
## Related Work
|
|
|
|
| 170 |
- [CUTLASS 3.x](https://github.com/NVIDIA/cutlass) β NVIDIA GEMM templates with WS
|
| 171 |
- [Tawa](https://arxiv.org/abs/2510.14719) β Automatic WS compiler (downstream of Twill)
|
| 172 |
- [Cypress](https://arxiv.org/abs/2504.07004) β Task-based GPU programming model
|
| 173 |
+
- [Nautilus](https://arxiv.org/abs/2604.14825) β Auto-scheduling tensor compiler (fills Twill's tile-size gap)
|
| 174 |
+
- [MPK](https://arxiv.org/abs/2512.22219) β Cross-kernel software pipelining
|
| 175 |
+
- [GS-Schedule](https://github.com/Yu-Maryland/Differentiable_Scheduler_ICML24) β GauS predecessor (categorical approach)
|