# Twill + GauS: Optimal GPU Kernel Scheduling

This repository implements two complementary papers on GPU kernel scheduling optimization:

1. **[Twill](https://arxiv.org/abs/2512.18134)** — "Optimal Software Pipelining and Warp Specialization for Tensor Core GPUs" (Soi et al., NVIDIA Research, 2024). Exact ILP+SMT solver that finds provably optimal schedules.

2. **[GauS](https://arxiv.org/abs/2602.20427)** — "Differentiable Scheduling Optimization via Gaussian Reparameterization" (Cai et al., 2026). Scalable gradient-based solver using Gaussian distributions and the Augmented Lagrangian Method.

## When to Use Which

| Property | Twill (ILP+SMT) | GauS (Differentiable) |
|----------|-----------------|----------------------|
| **Optimality** | Provably optimal | Approximate (Pareto-optimal) |
| **Speed on small graphs** | Fast (< 1s for 3-7 nodes) | Slower (~10s for 3-7 nodes) |
| **Scalability** | Exponential in graph size | O(\|V\|) parameters, GPU-accelerated |
| **Warp specialization** | Full joint SWP+WS | Schedule only (no WS) |
| **Best for** | Kernels with < 50 ops | Large graphs with 100+ ops |
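
The tradeoffs above boil down to a simple dispatch rule of thumb. The helper below is hypothetical (not part of the package API) and hard-codes the table's ~50-op crossover:

```python
def choose_solver(num_ops: int) -> str:
    """Pick a solver per the table above: exact ILP+SMT (Twill) stays
    tractable below roughly 50 ops; beyond that, gradient-based GauS
    scales with O(|V|) parameters. The threshold is a rule of thumb,
    not a hard limit."""
    return "twill" if num_ops < 50 else "gaus"
```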

## Architecture

```
┌──────────────────────────────────────────────────────────┐
│                    DependenceGraph                       │
│  G = (V, E), Machine Description, RRTs                   │
├────────────────────┬─────────────────────────────────────┤
│    Twill Solver    │         GauS Solver                 │
│                    │                                     │
│  Phase 1: ILP      │  Gaussian Reparameterization        │
│  (CBC/PuLP)        │  X_i ~ N(μ_i, σ_i²)                 │
│  → Modulo Schedule │  → P_i^d = Φ((d+0.5-μ)/σ)           │
│                    │          - Φ((d-0.5-μ)/σ)           │
│  Phase 2: SMT      │                                     │
│  (Z3/QFLIA)        │  Augmented Lagrangian Method        │
│  → Joint SWP + WS  │  → Adam optimizer on (μ, σ)         │
│                    │  → Dependency + Resource +          │
│  Cost Norm (§5.2)  │    Modulo + Recurrence losses       │
│  → Ratio-preserving│                                     │
│    cycle count     │  Legalization Heuristics            │
│    reduction       │  → Topological pass (regular)       │
│                    │  → Fixed-point iteration (modulo)   │
├────────────────────┴─────────────────────────────────────┤
│               Code Generation                            │
│  Pseudocode · CUDA Skeleton · Pipelined Schedule         │
└──────────────────────────────────────────────────────────┘
```

## Quick Start

```bash
pip install pulp z3-solver numpy matplotlib torch
```

### Twill (exact solver)
```python
from twill.kernels import flash_attention_forward_simplified
from twill.twill_solver import twill_solve

graph = flash_attention_forward_simplified()
result = twill_solve(graph, max_I=5, verbose=True)
# → I=2, schedule: S@0, P@2, O@3 (proves FA3 optimal)
```

### GauS (differentiable solver)
```python
from twill.kernels import flash_attention_forward_hopper
from twill.gaus_solver import gaus_solve_twill_graph

graph = flash_attention_forward_hopper()
result = gaus_solve_twill_graph(graph, target_II=4, D=20, verbose=True)
# → II=4, feasible schedule found via gradient descent
```

### GauS on custom large graphs
```python
from twill.gaus_solver import GauSSolver, generate_random_dag

graph = generate_random_dag(num_nodes=1000, edge_density=0.01, num_back_edges=50)
solver = GauSSolver(graph, D=500, lr=0.01)
result = solver.solve_modulo(II=10, R_cap=50.0, max_iters=2000)
```

## Pre-built Kernel Descriptions

| Kernel | Architecture | Twill Result | GauS Result |
|--------|-------------|-------------|-------------|
| `flash_attention_forward_simplified()` | Hopper | I=2 ✓ (optimal) | II=2 ✓ (feasible) |
| `flash_attention_forward_hopper()` | Hopper (H100) | I=4, FA3 pipeline ✓ | II=4 ✓ (feasible) |
| `flash_attention_forward_blackwell()` | Blackwell (B200) | I=2, FA4 strategy ✓ | II=2 ✓ (feasible) |
| `simple_gemm_pipeline()` | Hopper | I=2, TMA on producer ✓ | II=2 ✓ (feasible) |

## Module Structure

```
twill/
├── __init__.py              # Package exports (v0.2.0)
├── graph.py                 # DependenceGraph, Instruction, RRT, MachineDescription
├── cost_normalization.py    # §5.2: ILP-based cycle count normalization
├── modulo_scheduler.py      # Twill Phase 1: ILP modulo scheduling (CBC)
├── smt_joint.py             # Twill Phase 2: SMT joint SWP+WS (Z3)
├── twill_solver.py          # Twill Algorithm 1: Main search procedure
├── gaus_solver.py           # GauS: Differentiable scheduling (PyTorch)
├── codegen.py               # Code generation (pseudocode, CUDA skeleton)
├── visualization.py         # Schedule visualization (text + matplotlib)
└── kernels.py               # Pre-built kernel descriptions (FMHA, GEMM)
```

## GauS: Technical Details

### Gaussian Reparameterization (§3.1)
Each operator `v_i` is modeled as `X_i ~ N(μ_i, σ_i²)` with only **2|V| parameters** (vs. D·|V| for categorical approaches). The probability of scheduling at step `d`:

```
P_i^d = Φ((d+0.5-μ_i)/σ_i) - Φ((d-0.5-μ_i)/σ_i)
```
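
The discretized-Gaussian probabilities can be reproduced in a few lines of plain Python. This is a minimal sketch of the formula above, not the package's implementation; `step_probs` and its arguments are illustrative names:

```python
import math

def normal_cdf(x: float) -> float:
    """Standard normal CDF Phi(x), via the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def step_probs(mu: float, sigma: float, D: int) -> list[float]:
    """P_i^d = Phi((d+0.5-mu)/sigma) - Phi((d-0.5-mu)/sigma), d = 0..D-1."""
    return [
        normal_cdf((d + 0.5 - mu) / sigma) - normal_cdf((d - 0.5 - mu) / sigma)
        for d in range(D)
    ]

probs = step_probs(mu=3.2, sigma=0.8, D=10)
# Mass concentrates around d=3; the D probabilities sum to ~1 as long as
# the Gaussian lies well inside [0, D).
```

As `σ_i` shrinks during optimization, the distribution collapses toward a one-hot assignment of the operator to a single step.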

### Differentiable Loss Functions
- **Dependency**: Expected topological violations via CDF products
- **Resource**: LogSumExp smooth-max of per-step usage + ReLU violations
- **Modulo Resource**: Wrapped Gaussian probabilities into II-slot reservation table
- **Recurrence**: Expected back-edge constraint violations
- **Memory**: CDF-based active lifetime estimation
- **Communication**: Expected edge lengths (compactness)
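
As one concrete example from the list above, the resource loss can be sketched as expected per-step usage with a ReLU penalty on capacity violations. Names are mine, not the package's API, and the plain per-step sum stands in for the LogSumExp smooth-max used in the actual tensorized loss:

```python
def resource_loss(probs_per_op: list[list[float]],
                  costs: list[float],
                  cap: float) -> float:
    """Expected usage at step d is E[u_d] = sum_i P_i^d * r_i; violations
    above the capacity `cap` are penalized with a ReLU."""
    D = len(probs_per_op[0])
    loss = 0.0
    for d in range(D):
        expected_usage = sum(p[d] * r for p, r in zip(probs_per_op, costs))
        loss += max(0.0, expected_usage - cap)  # ReLU on the violation
    return loss
```

In the solver proper these would be autograd-tracked tensor operations, so gradients of the penalty flow back to (μ, σ).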

### Augmented Lagrangian Method (ALM)
```
L_total = L_primary + Σ_i (λ_i · V_i + ρ/2 · ||V_i||²)
λ_i ← λ_i + ρ · V_i    (dual update)
```
Hyperparameters (from the paper): `lr=0.01, ρ=1e-4, τ=1e-2, κ=1/6`

## Tests

```bash
python test_twill.py   # Twill: 6/6 pass (~5s)
python test_gaus.py    # GauS:  7/7 pass (~30s)
```

## Solvers Used

| Component | Solver | Theory | Paper |
|-----------|--------|--------|-------|
| Twill Cost Normalization | CBC (PuLP) | ILP | Soi et al. §5.2 |
| Twill Modulo Scheduling | CBC (PuLP) | ILP | Soi et al. §3.1 |
| Twill Joint SWP+WS | Z3 | QFLIA (SMT) | Soi et al. §4 |
| GauS Scheduling | PyTorch (Adam) | Differentiable | Cai et al. §3 |

## Citations

```bibtex
@article{soi2024twill,
  title={Optimal Software Pipelining and Warp Specialization for Tensor Core GPUs},
  author={Soi, Rupanshu and others},
  journal={arXiv preprint arXiv:2512.18134},
  year={2024}
}

@article{cai2026gaus,
  title={GauS: Differentiable Scheduling Optimization via Gaussian Reparameterization},
  author={Cai, Yaohui and others},
  journal={arXiv preprint arXiv:2602.20427},
  year={2026}
}
```

## Related Work

- [FlashAttention-3](https://arxiv.org/abs/2407.08608) — Hopper FMHA schedule that Twill rediscovers
- [FlashAttention-4](https://arxiv.org/abs/2603.05451) — Blackwell FMHA schedule that Twill rediscovers
- [ThunderKittens](https://github.com/HazyResearch/ThunderKittens) — Warp-level kernel framework
- [CUTLASS 3.x](https://github.com/NVIDIA/cutlass) — NVIDIA GEMM templates with WS
- [Tawa](https://arxiv.org/abs/2510.14719) — Automatic WS compiler (downstream of Twill)
- [Cypress](https://arxiv.org/abs/2504.07004) — Task-based GPU programming model
- [Nautilus](https://arxiv.org/abs/2604.14825) — Auto-scheduling tensor compiler (fills Twill's tile-size gap)
- [MPK](https://arxiv.org/abs/2512.22219) — Cross-kernel software pipelining
- [GS-Schedule](https://github.com/Yu-Maryland/Differentiable_Scheduler_ICML24) — GauS predecessor (categorical approach)