File size: 11,077 Bytes
67e6b48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
#!/usr/bin/env python3
"""
Test suite for Twill: Optimal Software Pipelining and Warp Specialization.

Reproduces the key results from the paper:
1. Simplified FA schedule (Figure 1) - GEMM and EXP have equal cost
2. Full FMHA forward on Hopper (Section 6.2.1) - rediscovers FA3 pipeline
3. Full FMHA forward on Blackwell (Section 6.2.2) - rediscovers FA4 strategy
4. Simple GEMM pipeline - load-compute overlap
"""

import sys
import time
import numpy as np

sys.path.insert(0, '/app')

from twill.graph import (
    DependenceGraph, Instruction, DependenceEdge,
    hopper_machine, blackwell_machine, make_rrt,
)
from twill.cost_normalization import normalize_costs
from twill.modulo_scheduler import (
    optimal_modulo_schedule, compute_modular_rrt, validate_schedule,
)
from twill.smt_joint import swp_and_ws
from twill.twill_solver import twill_solve
from twill.codegen import generate_pipelined_code, generate_pseudocode, generate_cuda_skeleton
from twill.visualization import visualize_schedule, print_modular_rrt
from twill.kernels import (
    flash_attention_forward_simplified,
    flash_attention_forward_hopper,
    flash_attention_forward_blackwell,
    simple_gemm_pipeline,
)


def test_cost_normalization():
    """Test cost normalization (Section 5.2).

    Normalizes two sets of per-instruction costs with ``normalize_costs`` and
    prints the original and normalized values. For the first set it also
    prints every pairwise cost ratio before and after normalization. The
    ratios are reported for inspection only; nothing here asserts on them.

    Returns:
        True once both normalizations have run without raising.
    """
    print("\n" + "=" * 70)
    print("TEST: Cost Normalization (Section 5.2)")
    print("=" * 70)

    # Example from the paper: GEMM ~1000 cycles on Hopper
    costs = {"WGMMA": 1000, "EXP": 1000, "TMA": 500}
    normalized, F = normalize_costs(costs, U=10)
    print(f"Original:   {costs}")
    print(f"Normalized: {normalized} (F={F})")

    # Report ratio preservation (printed, not asserted). A pair is skipped when
    # its denominator is zero either before or after normalization, so neither
    # division can fail.
    for k1 in costs:
        for k2 in costs:
            if costs[k2] > 0 and normalized[k2] > 0:
                orig_ratio = costs[k1] / costs[k2]
                norm_ratio = normalized[k1] / normalized[k2]
                print(f"  Ratio {k1}/{k2}: orig={orig_ratio:.3f}, norm={norm_ratio:.3f}")

    # Larger example
    costs2 = {"WGMMA": 1024, "EXP": 960, "TMA_K": 512, "TMA_V": 512, "TMEM_R": 128, "TMEM_W": 128}
    normalized2, F2 = normalize_costs(costs2, U=20)
    print(f"\nOriginal:   {costs2}")
    print(f"Normalized: {normalized2} (F={F2})")

    print("\n✓ Cost normalization test passed")
    return True


def test_simplified_flash_attention():
    """Test the simplified FA example from Figure 1.

    Expected result from the paper:
    - I = 2 (one GEMM and one EXP per 2 cycles)
    - L = 4 (schedule length)
    - ceil(L/I) = 2 copies overlapped
    - S is extracted into prologue

    Phase 1 (modulo scheduling alone) is tried at I=2 and, if no schedule is
    returned, at I=3. The schedule is checked with ``validate_schedule`` and
    its modular reservation table is printed. The full ``twill_solve``
    pipeline is then run, and its visualization, pseudocode and CUDA skeleton
    are printed. The expected values above are not asserted individually.

    Returns:
        True if Phase 1 produced a schedule that passes validation and
        ``twill_solve`` returned a result; False otherwise. Failures are
        reported through the return value rather than by raising, so the
        ``__main__`` runner can record them and continue with later tests.
    """
    print("\n" + "=" * 70)
    print("TEST: Simplified Flash Attention (Figure 1)")
    print("=" * 70)

    graph = flash_attention_forward_simplified()
    print(f"Graph: {graph}")
    print(f"Min I (resource bound): {graph.compute_min_initiation_interval()}")

    # Phase 1: Modulo scheduling
    print("\n--- Phase 1: Modulo Scheduling ---")
    schedule = optimal_modulo_schedule(graph, target_I=2, verbose=False)

    if schedule is None:
        print("Phase 1 FAILED for I=2, trying I=3...")
        schedule = optimal_modulo_schedule(graph, target_I=3, verbose=False)

    if schedule is None:
        # Report a normal failure instead of using ``assert``: an
        # AssertionError would abort the whole suite, and under ``python -O``
        # the check would vanish and ``None`` would flow into validate_schedule.
        print("FAILED: No modulo schedule found!")
        return False
    print(f"Schedule: {schedule}")

    phase1_valid, violations = validate_schedule(graph, schedule)
    print(f"Valid: {phase1_valid}")
    if not phase1_valid:
        for v in violations:
            print(f"  VIOLATION: {v}")

    print(f"\n{print_modular_rrt(graph, schedule)}")

    # Full Twill solve
    print("\n--- Full Twill Solve ---")
    result = twill_solve(
        graph,
        max_I=5,
        enable_cost_normalization=False,  # Costs already normalized (1 cycle each)
        enable_memory_constraints=False,  # Simplified example
        enable_warp_constraints=True,
        verbose=True,
    )

    if not result:
        print("FAILED: No solution found!")
        return False

    print(f"\n{result}")

    # Visualization
    viz = visualize_schedule(graph, result.joint_result,
                             output_path="/app/simplified_fa_schedule.png",
                             title="Simplified Flash Attention (Figure 1)")
    print(viz)

    # Code generation
    pseudo = generate_pseudocode(graph, result.joint_result)
    print("\n--- Generated Pseudocode ---")
    print(pseudo)

    # CUDA skeleton
    cuda = generate_cuda_skeleton(graph, result.joint_result)
    print("\n--- Generated CUDA Skeleton ---")
    print(cuda)

    # The joint solve succeeding does not excuse an invalid Phase 1 schedule.
    return bool(phase1_valid)


def test_hopper_fmha_forward():
    """Test full FMHA forward on Hopper (Section 6.2.1).

    Expected: Twill rediscovers FA3's SWP strategy:
    - S=gemm(Q,K[0]) extracted into prologue
    - Ping-pong scheduling between EXP and TC
    - TMA loads on variable-latency (producer) warp

    Builds the Hopper FMHA graph and runs ``twill_solve`` with max_I=10. Cost
    normalization and memory constraints are disabled; warp constraints are
    enabled. It then prints the schedule visualization and generated
    pseudocode. It also prints the warp chosen for each key instruction next
    to the expected placement; placements are printed, not asserted.

    Returns:
        True if ``twill_solve`` returned a result, False otherwise.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("TEST: FMHA Forward on Hopper (Section 6.2.1)")
    print(banner)

    graph = flash_attention_forward_hopper()
    print(f"Graph: {graph}")
    latency_kinds = [
        (instr.name, "var_lat" if instr.variable_latency else "fixed")
        for instr in graph.V
    ]
    print(f"Instructions: {latency_kinds}")
    print(f"Min I: {graph.compute_min_initiation_interval()}")

    solution = twill_solve(
        graph,
        max_I=10,
        enable_cost_normalization=False,
        enable_memory_constraints=False,
        enable_warp_constraints=True,
        verbose=True,
    )

    # Guard clause: nothing further to report without a solution.
    if not solution:
        print("FAILED: No solution found!")
        return False

    print(f"\n{solution}")

    schedule_art = visualize_schedule(
        graph,
        solution.joint_result,
        output_path="/app/hopper_fmha_schedule.png",
        title="FMHA Forward on Hopper",
    )
    print(schedule_art)

    pseudocode = generate_pseudocode(graph, solution.joint_result)
    print("\n--- Generated Pseudocode ---")
    print(pseudocode)

    # Report where the key instructions landed versus where FA3 puts them.
    assignment = solution.warp_assignment
    print("\n--- Verification ---")

    # TMA loads are expected on the variable-latency warp (W_vl = 0).
    load_ops = ("K_load", "V_load")
    load_warps = [assignment.warp_of(op) for op in load_ops]
    for op, warp in zip(load_ops, load_warps):
        print(f"{op} on warp {warp} (expected: 0/variable-latency)")

    # Tensor-core GEMMs are expected on compute warps (not warp 0).
    gemm_ops = ("S", "O")
    gemm_warps = [assignment.warp_of(op) for op in gemm_ops]
    for op, warp in zip(gemm_ops, gemm_warps):
        print(f"{op} (GEMM) on warp {warp} (expected: compute warp)")

    # The exponential is expected on a compute warp as well.
    exp_warp = assignment.warp_of("P")
    print(f"P (EXP) on warp {exp_warp} (expected: compute warp)")

    return True


def test_blackwell_fmha_forward():
    """Test FMHA forward on Blackwell (Section 6.2.2).

    Expected: Twill rediscovers FA4's strategy:
    - Variable-latency ops (TMA) on separate warp
    - TC GEMMs on compute warps
    - TMEM read/write for accumulator management
    - Different pipeline structure than Hopper

    Builds the Blackwell FMHA graph and runs ``twill_solve`` with max_I=15.
    Cost normalization and memory constraints are disabled; warp constraints
    are enabled. It then prints the schedule visualization, the generated
    pseudocode and the CUDA skeleton. The expectations above are not asserted.

    Returns:
        True if ``twill_solve`` returned a result, False otherwise.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("TEST: FMHA Forward on Blackwell (Section 6.2.2)")
    print(banner)

    graph = flash_attention_forward_blackwell()
    print(f"Graph: {graph}")
    latency_kinds = [
        (instr.name, "var_lat" if instr.variable_latency else "fixed")
        for instr in graph.V
    ]
    print(f"Instructions: {latency_kinds}")
    print(f"Min I: {graph.compute_min_initiation_interval()}")

    solution = twill_solve(
        graph,
        max_I=15,
        enable_cost_normalization=False,
        enable_memory_constraints=False,
        enable_warp_constraints=True,
        verbose=True,
    )

    # Guard clause: nothing further to report without a solution.
    if not solution:
        print("FAILED: No solution found!")
        return False

    print(f"\n{solution}")

    schedule_art = visualize_schedule(
        graph,
        solution.joint_result,
        output_path="/app/blackwell_fmha_schedule.png",
        title="FMHA Forward on Blackwell",
    )
    print(schedule_art)

    pseudocode = generate_pseudocode(graph, solution.joint_result)
    print("\n--- Generated Pseudocode ---")
    print(pseudocode)

    cuda_source = generate_cuda_skeleton(graph, solution.joint_result)
    print("\n--- Generated CUDA Skeleton ---")
    print(cuda_source)

    return True


def test_simple_gemm():
    """Test simple GEMM with load-compute overlap.

    Runs ``twill_solve`` (max_I=6) on the simple GEMM pipeline graph. Cost
    normalization and memory constraints are disabled; warp constraints are
    enabled. On success it prints the result and its schedule visualization.
    No output image path is passed to ``visualize_schedule``.

    Returns:
        True if ``twill_solve`` returned a result, False otherwise.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("TEST: Simple GEMM Pipeline")
    print(banner)

    pipeline = simple_gemm_pipeline()
    print(f"Graph: {pipeline}")

    solution = twill_solve(
        pipeline,
        max_I=6,
        enable_cost_normalization=False,
        enable_memory_constraints=False,
        enable_warp_constraints=True,
        verbose=True,
    )

    # Guard clause: nothing further to report without a solution.
    if not solution:
        print("FAILED: No solution found!")
        return False

    print(f"\n{solution}")
    schedule_art = visualize_schedule(
        pipeline,
        solution.joint_result,
        title="Simple GEMM Pipeline",
    )
    print(schedule_art)
    return True


def test_modulo_scheduling_only():
    """Test Phase 1 (modulo scheduling) in isolation.

    Runs ``optimal_modulo_schedule`` on the simplified flash-attention graph
    for each target initiation interval I = 1..5. When no schedule is
    returned for a value of I, it is reported as infeasible; that is not a
    failure. Every schedule that is returned is checked with
    ``validate_schedule``, and its violations are printed.

    Returns:
        True if every schedule returned passes validation, False otherwise.
    """
    print("\n" + "=" * 70)
    print("TEST: Modulo Scheduling Only (Phase 1)")
    print("=" * 70)

    graph = flash_attention_forward_simplified()

    all_valid = True
    # Try increasing I values
    for target_I in range(1, 6):
        print(f"\n  Trying I = {target_I}...")
        sched = optimal_modulo_schedule(graph, target_I=target_I, verbose=False)
        # Compare against None explicitly, matching the Figure 1 test, instead
        # of depending on how the schedule object defines truthiness.
        if sched is None:
            print("    Infeasible")
            continue

        valid, violations = validate_schedule(graph, sched)
        print(f"    Found: M={sched.schedule}, L={sched.length}, copies={sched.num_copies}")
        print(f"    Valid: {valid}")
        if not valid:
            # A schedule that violates constraints is a real failure; record it
            # rather than returning True unconditionally.
            all_valid = False
            for v in violations:
                print(f"      {v}")

    return all_valid


if __name__ == "__main__":
    import traceback

    # Banner
    print("╔" + "═" * 68 + "╗")
    print("║" + " Twill Test Suite ".center(68) + "║")
    print("║" + " Optimal SWP + WS for Tensor Core GPUs ".center(68) + "║")
    print("║" + " (arXiv:2512.18134) ".center(68) + "║")
    print("╚" + "═" * 68 + "╝")

    # (display name, test function) pairs, run in this order.
    tests = [
        ("Cost Normalization", test_cost_normalization),
        ("Modulo Scheduling Only", test_modulo_scheduling_only),
        ("Simplified FA (Figure 1)", test_simplified_flash_attention),
        ("Simple GEMM", test_simple_gemm),
        ("Hopper FMHA Forward", test_hopper_fmha_forward),
        ("Blackwell FMHA Forward", test_blackwell_fmha_forward),
    ]

    results = {}

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # wall-clock adjustments during a long solver run.
    start = time.perf_counter()

    for name, test_fn in tests:
        # Top-level boundary: an exception raised by one test is printed with
        # its traceback and recorded as a FAIL, instead of aborting the
        # remaining tests and the summary below.
        try:
            results[name] = bool(test_fn())
        except Exception:
            traceback.print_exc()
            results[name] = False

    elapsed = time.perf_counter() - start

    # Summary
    print("\n" + "=" * 70)
    print("TEST SUMMARY")
    print("=" * 70)
    for name, passed in results.items():
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f"  {status}  {name}")
    print(f"\nTotal time: {elapsed:.2f}s")
    print(f"Passed: {sum(results.values())}/{len(results)}")

    sys.exit(0 if all(results.values()) else 1)