| |
| """ |
| Test suite for Twill: Optimal Software Pipelining and Warp Specialization. |
| |
| Reproduces the key results from the paper: |
| 1. Simplified FA schedule (Figure 1) - GEMM and EXP have equal cost |
| 2. Full FMHA forward on Hopper (Section 6.2.1) - rediscovers FA3 pipeline |
| 3. Full FMHA forward on Blackwell (Section 6.2.2) - rediscovers FA4 strategy |
| 4. Simple GEMM pipeline - load-compute overlap |
| """ |
|
|
| import sys |
| import time |
| import numpy as np |
|
|
| sys.path.insert(0, '/app') |
|
|
| from twill.graph import ( |
| DependenceGraph, Instruction, DependenceEdge, |
| hopper_machine, blackwell_machine, make_rrt, |
| ) |
| from twill.cost_normalization import normalize_costs |
| from twill.modulo_scheduler import ( |
| optimal_modulo_schedule, compute_modular_rrt, validate_schedule, |
| ) |
| from twill.smt_joint import swp_and_ws |
| from twill.twill_solver import twill_solve |
| from twill.codegen import generate_pipelined_code, generate_pseudocode, generate_cuda_skeleton |
| from twill.visualization import visualize_schedule, print_modular_rrt |
| from twill.kernels import ( |
| flash_attention_forward_simplified, |
| flash_attention_forward_hopper, |
| flash_attention_forward_blackwell, |
| simple_gemm_pipeline, |
| ) |
|
|
|
|
def test_cost_normalization():
    """Test cost normalization (Section 5.2).

    Normalizes two cost tables with ``normalize_costs`` and prints the
    original and normalized costs. For the first table it also prints every
    pairwise cost ratio before and after normalization, so ratio
    preservation can be inspected by eye.

    Returns:
        True if every instruction of each input table is present in the
        corresponding normalized table, False otherwise.
    """
    print("\n" + "=" * 70)
    print("TEST: Cost Normalization (Section 5.2)")
    print("=" * 70)

    # Case 1: three instruction classes normalized with U=10.
    costs = {"WGMMA": 1000, "EXP": 1000, "TMA": 500}
    normalized, F = normalize_costs(costs, U=10)
    print(f"Original: {costs}")
    print(f"Normalized: {normalized} (F={F})")

    # Check explicitly so a dropped instruction is reported as a test
    # failure instead of surfacing as a bare KeyError in the ratio loop.
    missing = [k for k in costs if k not in normalized]
    if missing:
        print(f"FAILED: normalized costs are missing {missing}")
        return False

    for k1 in costs:
        for k2 in costs:
            # A self-ratio is trivially 1.0; skip it to keep the report short.
            # Zero denominators (original or normalized) are skipped too.
            if k1 == k2 or costs[k2] <= 0 or normalized[k2] <= 0:
                continue
            orig_ratio = costs[k1] / costs[k2]
            norm_ratio = normalized[k1] / normalized[k2]
            print(f"  Ratio {k1}/{k2}: orig={orig_ratio:.3f}, norm={norm_ratio:.3f}")

    # Case 2: a Blackwell-like cost table normalized with U=20.
    costs2 = {"WGMMA": 1024, "EXP": 960, "TMA_K": 512, "TMA_V": 512, "TMEM_R": 128, "TMEM_W": 128}
    normalized2, F2 = normalize_costs(costs2, U=20)
    print(f"\nOriginal: {costs2}")
    print(f"Normalized: {normalized2} (F={F2})")

    missing2 = [k for k in costs2 if k not in normalized2]
    if missing2:
        print(f"FAILED: normalized costs are missing {missing2}")
        return False

    print("\n✓ Cost normalization test passed")
    return True
|
|
|
|
def test_simplified_flash_attention():
    """Test the simplified FA example from Figure 1.

    Expected result from the paper:
      - I = 2 (one GEMM and one EXP per 2 cycles)
      - L = 4 (schedule length)
      - ceil(L/I) = 2 copies overlapped
      - S is extracted into prologue

    The test first runs Phase 1 (modulo scheduling) on its own. It tries
    I=2 and falls back to I=3, then checks the schedule with
    ``validate_schedule``. It then runs the full Twill solve and prints the
    schedule visualization, pseudocode and CUDA skeleton. The expected
    values listed above are printed for inspection, not asserted.

    Returns:
        True if Phase 1 produced a schedule that passes validation and the
        full solve found a solution, False otherwise.
    """
    print("\n" + "=" * 70)
    print("TEST: Simplified Flash Attention (Figure 1)")
    print("=" * 70)

    graph = flash_attention_forward_simplified()
    print(f"Graph: {graph}")
    print(f"Min I (resource bound): {graph.compute_min_initiation_interval()}")

    # Phase 1 in isolation: target the paper's I=2, fall back to I=3.
    print("\n--- Phase 1: Modulo Scheduling ---")
    schedule = optimal_modulo_schedule(graph, target_I=2, verbose=False)
    if schedule is None:
        print("Phase 1 FAILED for I=2, trying I=3...")
        schedule = optimal_modulo_schedule(graph, target_I=3, verbose=False)

    # Report failure through the return value, like the other tests, rather
    # than with `assert`: asserts are stripped under `python -O`, and an
    # uncaught AssertionError would abort the remaining tests and the summary.
    if schedule is None:
        print("FAILED: No modulo schedule found!")
        return False
    print(f"Schedule: {schedule}")

    phase1_valid, violations = validate_schedule(graph, schedule)
    print(f"Valid: {phase1_valid}")
    if not phase1_valid:
        for v in violations:
            print(f"  VIOLATION: {v}")

    print(f"\n{print_modular_rrt(graph, schedule)}")

    # Full Twill solve.
    print("\n--- Full Twill Solve ---")
    result = twill_solve(
        graph,
        max_I=5,
        enable_cost_normalization=False,
        enable_memory_constraints=False,
        enable_warp_constraints=True,
        verbose=True,
    )
    if not result:
        print("FAILED: No solution found!")
        return False

    print(f"\n{result}")

    viz = visualize_schedule(graph, result.joint_result,
                             output_path="/app/simplified_fa_schedule.png",
                             title="Simplified Flash Attention (Figure 1)")
    print(viz)

    pseudo = generate_pseudocode(graph, result.joint_result)
    print("\n--- Generated Pseudocode ---")
    print(pseudo)

    cuda = generate_cuda_skeleton(graph, result.joint_result)
    print("\n--- Generated CUDA Skeleton ---")
    print(cuda)

    # A Phase 1 schedule that violates its constraints is a failure even
    # when the full solve succeeds.
    if not phase1_valid:
        print("FAILED: Phase 1 schedule violates constraints (see VIOLATION lines above)")
        return False
    return True
|
|
|
|
def test_hopper_fmha_forward():
    """Test full FMHA forward on Hopper (Section 6.2.1).

    Expected: Twill rediscovers FA3's SWP strategy:
      - S=gemm(Q,K[0]) extracted into prologue
      - Ping-pong scheduling between EXP and TC
      - TMA loads on variable-latency (producer) warp

    Solves the Hopper FMHA graph with ``twill_solve`` (max_I=10) and prints
    the schedule visualization and pseudocode. It then prints the warp that
    each of K_load/V_load, S/O and P was assigned, next to the placement the
    paper expects. Warp placements are reported, not asserted.

    Returns:
        True if a solution was found, False otherwise.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("TEST: FMHA Forward on Hopper (Section 6.2.1)")
    print(rule)

    graph = flash_attention_forward_hopper()
    print(f"Graph: {graph}")
    latency_kinds = [
        (instr.name, "var_lat" if instr.variable_latency else "fixed")
        for instr in graph.V
    ]
    print(f"Instructions: {latency_kinds}")
    print(f"Min I: {graph.compute_min_initiation_interval()}")

    solution = twill_solve(
        graph,
        max_I=10,
        enable_cost_normalization=False,
        enable_memory_constraints=False,
        enable_warp_constraints=True,
        verbose=True,
    )

    # Guard clause: bail out early when the solver found nothing.
    if not solution:
        print("FAILED: No solution found!")
        return False

    print(f"\n{solution}")

    print(visualize_schedule(graph, solution.joint_result,
                             output_path="/app/hopper_fmha_schedule.png",
                             title="FMHA Forward on Hopper"))

    code_listing = generate_pseudocode(graph, solution.joint_result)
    print("\n--- Generated Pseudocode ---")
    print(code_listing)

    assignment = solution.warp_assignment
    print("\n--- Verification ---")

    # TMA loads: expected on the variable-latency (producer) warp.
    load_warps = {op: assignment.warp_of(op) for op in ("K_load", "V_load")}
    for op, warp in load_warps.items():
        print(f"{op} on warp {warp} (expected: 0/variable-latency)")

    # Tensor-core GEMMs: expected on a compute warp.
    gemm_warps = {op: assignment.warp_of(op) for op in ("S", "O")}
    for op, warp in gemm_warps.items():
        print(f"{op} (GEMM) on warp {warp} (expected: compute warp)")

    # Exponentiation: expected on a compute warp.
    exp_warp = assignment.warp_of("P")
    print(f"P (EXP) on warp {exp_warp} (expected: compute warp)")

    return True
|
|
|
|
def test_blackwell_fmha_forward():
    """Test FMHA forward on Blackwell (Section 6.2.2).

    Expected: Twill rediscovers FA4's strategy:
      - Variable-latency ops (TMA) on separate warp
      - TC GEMMs on compute warps
      - TMEM read/write for accumulator management
      - Different pipeline structure than Hopper

    Solves the Blackwell FMHA graph with ``twill_solve`` (max_I=15), then
    prints the schedule visualization, pseudocode and CUDA skeleton. The
    expectations above are checked by inspection only.

    Returns:
        True if a solution was found, False otherwise.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("TEST: FMHA Forward on Blackwell (Section 6.2.2)")
    print(rule)

    graph = flash_attention_forward_blackwell()
    print(f"Graph: {graph}")
    latency_kinds = [
        (instr.name, "var_lat" if instr.variable_latency else "fixed")
        for instr in graph.V
    ]
    print(f"Instructions: {latency_kinds}")
    print(f"Min I: {graph.compute_min_initiation_interval()}")

    solution = twill_solve(
        graph,
        max_I=15,
        enable_cost_normalization=False,
        enable_memory_constraints=False,
        enable_warp_constraints=True,
        verbose=True,
    )

    # Guard clause: bail out early when the solver found nothing.
    if not solution:
        print("FAILED: No solution found!")
        return False

    print(f"\n{solution}")

    joint = solution.joint_result
    print(visualize_schedule(graph, joint,
                             output_path="/app/blackwell_fmha_schedule.png",
                             title="FMHA Forward on Blackwell"))

    code_listing = generate_pseudocode(graph, joint)
    print("\n--- Generated Pseudocode ---")
    print(code_listing)

    kernel_skeleton = generate_cuda_skeleton(graph, joint)
    print("\n--- Generated CUDA Skeleton ---")
    print(kernel_skeleton)

    return True
|
|
|
|
def test_simple_gemm():
    """Test simple GEMM with load-compute overlap.

    Solves the simple GEMM pipeline graph with ``twill_solve`` (max_I=6) and
    prints the solution and its schedule visualization.

    Returns:
        True if a solution was found, False otherwise.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("TEST: Simple GEMM Pipeline")
    print(rule)

    graph = simple_gemm_pipeline()
    print(f"Graph: {graph}")

    solution = twill_solve(
        graph,
        max_I=6,
        enable_cost_normalization=False,
        enable_memory_constraints=False,
        enable_warp_constraints=True,
        verbose=True,
    )

    # Guard clause: bail out early when the solver found nothing.
    if not solution:
        print("FAILED: No solution found!")
        return False

    print(f"\n{solution}")
    print(visualize_schedule(graph, solution.joint_result,
                             title="Simple GEMM Pipeline"))
    return True
|
|
|
|
def test_modulo_scheduling_only():
    """Test Phase 1 (modulo scheduling) in isolation.

    Sweeps the target initiation interval I over 1..5 on the simplified FA
    graph. For each I it prints either the schedule found (with its length
    and number of overlapped copies) and whether it passes
    ``validate_schedule``, or "Infeasible".

    Returns:
        True if at least one I produced a schedule and every schedule found
        passed validation, False otherwise.
    """
    print("\n" + "=" * 70)
    print("TEST: Modulo Scheduling Only (Phase 1)")
    print("=" * 70)

    graph = flash_attention_forward_simplified()

    found_any = False
    all_valid = True
    for target_I in range(1, 6):
        print(f"\n  Trying I = {target_I}...")
        sched = optimal_modulo_schedule(graph, target_I=target_I, verbose=False)
        # Compare against None explicitly, as elsewhere in this file, rather
        # than relying on the truthiness of the schedule object.
        if sched is None:
            print("    Infeasible")
            continue

        found_any = True
        valid, violations = validate_schedule(graph, sched)
        print(f"    Found: M={sched.schedule}, L={sched.length}, copies={sched.num_copies}")
        print(f"    Valid: {valid}")
        if not valid:
            all_valid = False
            for v in violations:
                print(f"      {v}")

    # Previously this test returned True unconditionally; make the verdict
    # reflect what the sweep actually observed.
    if not found_any:
        print("FAILED: no feasible modulo schedule for any I in 1..5")
    elif not all_valid:
        print("FAILED: at least one modulo schedule violates its constraints")
    return found_any and all_valid
|
|
|
|
if __name__ == "__main__":
    import traceback

    # Banner box, 68 columns between the borders.
    print("╔" + "═" * 68 + "╗")
    print("║" + " Twill Test Suite ".center(68) + "║")
    print("║" + " Optimal SWP + WS for Tensor Core GPUs ".center(68) + "║")
    print("║" + " (arXiv:2512.18134) ".center(68) + "║")
    print("╚" + "═" * 68 + "╝")

    # Ordered (summary label, test function) pairs. Each test prints its own
    # report and returns True on pass, False on failure.
    suite = [
        ("Cost Normalization", test_cost_normalization),
        ("Modulo Scheduling Only", test_modulo_scheduling_only),
        ("Simplified FA (Figure 1)", test_simplified_flash_attention),
        ("Simple GEMM", test_simple_gemm),
        ("Hopper FMHA Forward", test_hopper_fmha_forward),
        ("Blackwell FMHA Forward", test_blackwell_fmha_forward),
    ]

    results = {}
    # perf_counter is monotonic, unlike time.time(), so it suits durations.
    start = time.perf_counter()
    for name, test_fn in suite:
        # Top-level boundary: an exception in one test (e.g. a solver error)
        # is recorded as a FAIL with its traceback, instead of aborting the
        # remaining tests and the summary.
        try:
            results[name] = bool(test_fn())
        except Exception:
            print(f"\nERROR: test '{name}' raised an exception:")
            # Print to stdout so the traceback stays in order with the report.
            traceback.print_exc(file=sys.stdout)
            results[name] = False
    elapsed = time.perf_counter() - start

    print("\n" + "=" * 70)
    print("TEST SUMMARY")
    print("=" * 70)
    for name, passed in results.items():
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f"  {status}  {name}")
    print(f"\nTotal time: {elapsed:.2f}s")
    print(f"Passed: {sum(results.values())}/{len(results)}")

    sys.exit(0 if all(results.values()) else 1)
|
|