#!/usr/bin/env python3
"""
Test suite for Twill: Optimal Software Pipelining and Warp Specialization.

Reproduces the key results from the paper:
1. Simplified FA schedule (Figure 1) - GEMM and EXP have equal cost
2. Full FMHA forward on Hopper (Section 6.2.1) - rediscovers FA3 pipeline
3. Full FMHA forward on Blackwell (Section 6.2.2) - rediscovers FA4 strategy
4. Simple GEMM pipeline - load-compute overlap

Each ``test_*`` function prints its progress and returns True on success and
False on failure; the ``__main__`` driver runs them all, prints a summary and
exits with status 0 only if every test passed.
"""

import sys
import time
import traceback

import numpy as np

# The twill package is expected to live under /app; this must run before the
# twill imports below.
sys.path.insert(0, '/app')

from twill.graph import (
    DependenceGraph,
    Instruction,
    DependenceEdge,
    hopper_machine,
    blackwell_machine,
    make_rrt,
)
from twill.cost_normalization import normalize_costs
from twill.modulo_scheduler import (
    optimal_modulo_schedule,
    compute_modular_rrt,
    validate_schedule,
)
from twill.smt_joint import swp_and_ws
from twill.twill_solver import twill_solve
from twill.codegen import generate_pipelined_code, generate_pseudocode, generate_cuda_skeleton
from twill.visualization import visualize_schedule, print_modular_rrt
from twill.kernels import (
    flash_attention_forward_simplified,
    flash_attention_forward_hopper,
    flash_attention_forward_blackwell,
    simple_gemm_pipeline,
)


def test_cost_normalization():
    """Test cost normalization (Section 5.2).

    Normalizes two cost tables with ``normalize_costs`` and prints the
    original and normalized pairwise ratios for the first one.

    Returns:
        True once both normalizations have run without raising.
    """
    print("\n" + "=" * 70)
    print("TEST: Cost Normalization (Section 5.2)")
    print("=" * 70)

    # Example from the paper: GEMM ~1000 cycles on Hopper
    costs = {"WGMMA": 1000, "EXP": 1000, "TMA": 500}
    normalized, F = normalize_costs(costs, U=10)
    print(f"Original: {costs}")
    print(f"Normalized: {normalized} (F={F})")

    # Report ratio preservation.
    # NOTE(review): the ratios are only printed, not asserted -- the rounding
    # tolerance of normalize_costs is not visible from this file; confirm it
    # before turning this into a hard check.
    for k1 in costs:
        for k2 in costs:
            # Skip pairs whose denominator would be zero.
            if costs[k2] > 0 and normalized[k2] > 0:
                orig_ratio = costs[k1] / costs[k2]
                norm_ratio = normalized[k1] / normalized[k2]
                print(f" Ratio {k1}/{k2}: orig={orig_ratio:.3f}, norm={norm_ratio:.3f}")

    # Larger example
    costs2 = {"WGMMA": 1024, "EXP": 960, "TMA_K": 512, "TMA_V": 512,
              "TMEM_R": 128, "TMEM_W": 128}
    normalized2, F2 = normalize_costs(costs2, U=20)
    print(f"\nOriginal: {costs2}")
    print(f"Normalized: {normalized2} (F={F2})")

    print("\n✓ Cost normalization test passed")
    return True


def test_simplified_flash_attention():
    """Test the simplified FA example from Figure 1.

    Expected result from the paper:
    - I = 2 (one GEMM and one EXP per 2 cycles)
    - L = 4 (schedule length)
    - ceil(L/I) = 2 copies overlapped
    - S is extracted into prologue

    Returns:
        True if a modulo schedule is found, ``validate_schedule`` accepts it,
        and ``twill_solve`` produces a result; False otherwise.
    """
    print("\n" + "=" * 70)
    print("TEST: Simplified Flash Attention (Figure 1)")
    print("=" * 70)

    graph = flash_attention_forward_simplified()
    print(f"Graph: {graph}")
    print(f"Min I (resource bound): {graph.compute_min_initiation_interval()}")

    # Phase 1: Modulo scheduling
    print("\n--- Phase 1: Modulo Scheduling ---")
    schedule = optimal_modulo_schedule(graph, target_I=2, verbose=False)
    if schedule is None:
        print("Phase 1 FAILED for I=2, trying I=3...")
        schedule = optimal_modulo_schedule(graph, target_I=3, verbose=False)
    if schedule is None:
        # Explicit failure instead of `assert`, which is stripped under -O.
        print("FAILED: No modulo schedule found!")
        return False

    print(f"Schedule: {schedule}")
    valid, violations = validate_schedule(graph, schedule)
    print(f"Valid: {valid}")
    if not valid:
        for v in violations:
            print(f" VIOLATION: {v}")

    print(f"\n{print_modular_rrt(graph, schedule)}")

    # Full Twill solve
    print("\n--- Full Twill Solve ---")
    result = twill_solve(
        graph,
        max_I=5,
        enable_cost_normalization=False,  # Costs already normalized (1 cycle each)
        enable_memory_constraints=False,  # Simplified example
        enable_warp_constraints=True,
        verbose=True,
    )

    if not result:
        print("FAILED: No solution found!")
        return False

    print(f"\n{result}")

    # Visualization
    viz = visualize_schedule(graph, result.joint_result,
                             output_path="/app/simplified_fa_schedule.png",
                             title="Simplified Flash Attention (Figure 1)")
    print(viz)

    # Code generation
    pseudo = generate_pseudocode(graph, result.joint_result)
    print("\n--- Generated Pseudocode ---")
    print(pseudo)

    # CUDA skeleton
    cuda = generate_cuda_skeleton(graph, result.joint_result)
    print("\n--- Generated CUDA Skeleton ---")
    print(cuda)

    # An invalid Phase 1 schedule fails the test even if the full solve
    # succeeded; the violations were already printed above.
    return bool(valid)


def test_hopper_fmha_forward():
    """Test full FMHA forward on Hopper (Section 6.2.1).

    Expected: Twill rediscovers FA3's SWP strategy:
    - S=gemm(Q,K[0]) extracted into prologue
    - Ping-pong scheduling between EXP and TC
    - TMA loads on variable-latency (producer) warp

    Returns:
        True if ``twill_solve`` produces a result, False otherwise.
    """
    print("\n" + "=" * 70)
    print("TEST: FMHA Forward on Hopper (Section 6.2.1)")
    print("=" * 70)

    graph = flash_attention_forward_hopper()
    print(f"Graph: {graph}")
    print(f"Instructions: {[(v.name, 'var_lat' if v.variable_latency else 'fixed') for v in graph.V]}")
    print(f"Min I: {graph.compute_min_initiation_interval()}")

    result = twill_solve(
        graph,
        max_I=10,
        enable_cost_normalization=False,
        enable_memory_constraints=False,
        enable_warp_constraints=True,
        verbose=True,
    )

    if not result:
        print("FAILED: No solution found!")
        return False

    print(f"\n{result}")

    viz = visualize_schedule(graph, result.joint_result,
                             output_path="/app/hopper_fmha_schedule.png",
                             title="FMHA Forward on Hopper")
    print(viz)

    pseudo = generate_pseudocode(graph, result.joint_result)
    print("\n--- Generated Pseudocode ---")
    print(pseudo)

    # Verify key properties.
    # NOTE(review): the warp placements below are reported, not asserted --
    # confirm the solver's guaranteed output before making them hard checks.
    wa = result.warp_assignment
    print("\n--- Verification ---")

    # TMA loads should be on variable-latency warp (W_vl = 0)
    k_warp = wa.warp_of("K_load")
    v_warp = wa.warp_of("V_load")
    print(f"K_load on warp {k_warp} (expected: 0/variable-latency)")
    print(f"V_load on warp {v_warp} (expected: 0/variable-latency)")

    # TC operations should be on compute warps (not warp 0)
    s_warp = wa.warp_of("S")
    o_warp = wa.warp_of("O")
    print(f"S (GEMM) on warp {s_warp} (expected: compute warp)")
    print(f"O (GEMM) on warp {o_warp} (expected: compute warp)")

    # EXP should be on a compute warp
    p_warp = wa.warp_of("P")
    print(f"P (EXP) on warp {p_warp} (expected: compute warp)")

    return True


def test_blackwell_fmha_forward():
    """Test FMHA forward on Blackwell (Section 6.2.2).

    Expected: Twill rediscovers FA4's strategy:
    - Variable-latency ops (TMA) on separate warp
    - TC GEMMs on compute warps
    - TMEM read/write for accumulator management
    - Different pipeline structure than Hopper

    Returns:
        True if ``twill_solve`` produces a result, False otherwise.
    """
    print("\n" + "=" * 70)
    print("TEST: FMHA Forward on Blackwell (Section 6.2.2)")
    print("=" * 70)

    graph = flash_attention_forward_blackwell()
    print(f"Graph: {graph}")
    print(f"Instructions: {[(v.name, 'var_lat' if v.variable_latency else 'fixed') for v in graph.V]}")
    print(f"Min I: {graph.compute_min_initiation_interval()}")

    result = twill_solve(
        graph,
        max_I=15,
        enable_cost_normalization=False,
        enable_memory_constraints=False,
        enable_warp_constraints=True,
        verbose=True,
    )

    if not result:
        print("FAILED: No solution found!")
        return False

    print(f"\n{result}")

    viz = visualize_schedule(graph, result.joint_result,
                             output_path="/app/blackwell_fmha_schedule.png",
                             title="FMHA Forward on Blackwell")
    print(viz)

    pseudo = generate_pseudocode(graph, result.joint_result)
    print("\n--- Generated Pseudocode ---")
    print(pseudo)

    cuda = generate_cuda_skeleton(graph, result.joint_result)
    print("\n--- Generated CUDA Skeleton ---")
    print(cuda)

    return True


def test_simple_gemm():
    """Test simple GEMM with load-compute overlap.

    Returns:
        True if ``twill_solve`` produces a result, False otherwise.
    """
    print("\n" + "=" * 70)
    print("TEST: Simple GEMM Pipeline")
    print("=" * 70)

    graph = simple_gemm_pipeline()
    print(f"Graph: {graph}")

    result = twill_solve(
        graph,
        max_I=6,
        enable_cost_normalization=False,
        enable_memory_constraints=False,
        enable_warp_constraints=True,
        verbose=True,
    )

    if not result:
        print("FAILED: No solution found!")
        return False

    print(f"\n{result}")
    viz = visualize_schedule(graph, result.joint_result,
                             title="Simple GEMM Pipeline")
    print(viz)
    return True


def test_modulo_scheduling_only():
    """Test Phase 1 (modulo scheduling) in isolation.

    Tries initiation intervals 1 through 5 on the simplified flash-attention
    graph and validates every schedule that is found.

    Returns:
        False if any schedule that was found fails ``validate_schedule``;
        True otherwise (infeasible intervals are reported, not failures).
    """
    print("\n" + "=" * 70)
    print("TEST: Modulo Scheduling Only (Phase 1)")
    print("=" * 70)

    graph = flash_attention_forward_simplified()

    all_valid = True
    # Try increasing I values
    for target_I in range(1, 6):
        print(f"\n Trying I = {target_I}...")
        sched = optimal_modulo_schedule(graph, target_I=target_I, verbose=False)
        if not sched:
            print(" Infeasible")
            continue

        valid, violations = validate_schedule(graph, sched)
        print(f" Found: M={sched.schedule}, L={sched.length}, copies={sched.num_copies}")
        print(f" Valid: {valid}")
        if not valid:
            # A schedule the solver returned but the validator rejects is a
            # real failure; keep going so every interval is still reported.
            all_valid = False
            for v in violations:
                print(f" {v}")

    return all_valid


def _run_test(test_fn):
    """Run one test callable, turning any raised exception into a failure.

    Args:
        test_fn: Zero-argument callable returning a truthy value on success.

    Returns:
        True if ``test_fn`` returned a truthy value, False if it returned a
        falsy value or raised.
    """
    try:
        return bool(test_fn())
    except Exception:
        # Top-level boundary: report the traceback and keep running so the
        # remaining tests and the summary are not lost.
        traceback.print_exc()
        return False


def main():
    """Run every test, print a summary, and return the process exit code.

    Returns:
        0 if all tests passed, 1 otherwise.
    """
    print("╔" + "═" * 68 + "╗")
    print("║" + " Twill Test Suite ".center(68) + "║")
    print("║" + " Optimal SWP + WS for Tensor Core GPUs ".center(68) + "║")
    print("║" + " (arXiv:2512.18134) ".center(68) + "║")
    print("╚" + "═" * 68 + "╝")

    # Execution order is significant only for the printed log and summary.
    tests = [
        ("Cost Normalization", test_cost_normalization),
        ("Modulo Scheduling Only", test_modulo_scheduling_only),
        ("Simplified FA (Figure 1)", test_simplified_flash_attention),
        ("Simple GEMM", test_simple_gemm),
        ("Hopper FMHA Forward", test_hopper_fmha_forward),
        ("Blackwell FMHA Forward", test_blackwell_fmha_forward),
    ]

    results = {}
    start = time.time()

    # Run tests
    for name, test_fn in tests:
        results[name] = _run_test(test_fn)

    elapsed = time.time() - start

    # Summary
    print("\n" + "=" * 70)
    print("TEST SUMMARY")
    print("=" * 70)
    for name, passed in results.items():
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f" {status} {name}")
    print(f"\nTotal time: {elapsed:.2f}s")
    print(f"Passed: {sum(results.values())}/{len(results)}")

    return 0 if all(results.values()) else 1


if __name__ == "__main__":
    sys.exit(main())