twill-swp-ws / test_twill.py
AshenNav's picture
Upload test_twill.py with huggingface_hub
67e6b48 verified
#!/usr/bin/env python3
"""
Test suite for Twill: Optimal Software Pipelining and Warp Specialization.
Reproduces the key results from the paper:
1. Simplified FA schedule (Figure 1) - GEMM and EXP have equal cost
2. Full FMHA forward on Hopper (Section 6.2.1) - rediscovers FA3 pipeline
3. Full FMHA forward on Blackwell (Section 6.2.2) - rediscovers FA4 strategy
4. Simple GEMM pipeline - load-compute overlap
"""
import sys
import time
import numpy as np
sys.path.insert(0, '/app')
from twill.graph import (
DependenceGraph, Instruction, DependenceEdge,
hopper_machine, blackwell_machine, make_rrt,
)
from twill.cost_normalization import normalize_costs
from twill.modulo_scheduler import (
optimal_modulo_schedule, compute_modular_rrt, validate_schedule,
)
from twill.smt_joint import swp_and_ws
from twill.twill_solver import twill_solve
from twill.codegen import generate_pipelined_code, generate_pseudocode, generate_cuda_skeleton
from twill.visualization import visualize_schedule, print_modular_rrt
from twill.kernels import (
flash_attention_forward_simplified,
flash_attention_forward_hopper,
flash_attention_forward_blackwell,
simple_gemm_pipeline,
)
def test_cost_normalization():
    """Run ``normalize_costs`` on two cost tables and print the outcome (Section 5.2).

    The first table is the small example from the paper; the second is a
    larger table with six entries. For the first table every pairwise ratio
    is printed before and after normalization so they can be compared by eye.

    This test is informational: nothing is asserted about the normalized
    values, so it returns True whenever ``normalize_costs`` does not raise.

    Returns:
        bool: always True.
    """
    print("\n" + "=" * 70)
    print("TEST: Cost Normalization (Section 5.2)")
    print("=" * 70)

    # Example from the paper: GEMM ~1000 cycles on Hopper
    costs = {"WGMMA": 1000, "EXP": 1000, "TMA": 500}
    normalized, F = normalize_costs(costs, U=10)
    print(f"Original: {costs}")
    print(f"Normalized: {normalized} (F={F})")

    # Print each pairwise ratio before/after normalization. Pairs whose
    # denominator is zero (in either table) are skipped to avoid dividing
    # by zero.
    for k1 in costs:
        for k2 in costs:
            if costs[k2] > 0 and normalized[k2] > 0:
                orig_ratio = costs[k1] / costs[k2]
                norm_ratio = normalized[k1] / normalized[k2]
                print(f" Ratio {k1}/{k2}: orig={orig_ratio:.3f}, norm={norm_ratio:.3f}")

    # Larger example
    costs2 = {"WGMMA": 1024, "EXP": 960, "TMA_K": 512, "TMA_V": 512, "TMEM_R": 128, "TMEM_W": 128}
    normalized2, F2 = normalize_costs(costs2, U=20)
    print(f"\nOriginal: {costs2}")
    print(f"Normalized: {normalized2} (F={F2})")

    # The check mark was previously mis-encoded (UTF-8 bytes read with the
    # wrong code page); emit the intended U+2713 glyph.
    print("\n✓ Cost normalization test passed")
    return True
def test_simplified_flash_attention():
    """Test the simplified FA example from Figure 1.

    Expected result from the paper:
    - I = 2 (one GEMM and one EXP per 2 cycles)
    - L = 4 (schedule length)
    - ceil(L/I) = 2 copies overlapped
    - S is extracted into prologue

    The test first runs Phase 1 (modulo scheduling) at I=2, falling back to
    I=3, and validates the schedule. It then runs the full ``twill_solve``
    and prints the visualization, pseudocode and CUDA skeleton.

    Returns:
        bool: True only if the full solve produced a result AND the Phase 1
        schedule passed ``validate_schedule``; False otherwise.

    Raises:
        AssertionError: if no modulo schedule is found at I=2 or I=3.
    """
    print("\n" + "=" * 70)
    print("TEST: Simplified Flash Attention (Figure 1)")
    print("=" * 70)

    graph = flash_attention_forward_simplified()
    print(f"Graph: {graph}")
    print(f"Min I (resource bound): {graph.compute_min_initiation_interval()}")

    # Phase 1: Modulo scheduling
    print("\n--- Phase 1: Modulo Scheduling ---")
    schedule = optimal_modulo_schedule(graph, target_I=2, verbose=False)
    if schedule is None:
        print("Phase 1 FAILED for I=2, trying I=3...")
        schedule = optimal_modulo_schedule(graph, target_I=3, verbose=False)
    assert schedule is not None, "No modulo schedule found!"
    print(f"Schedule: {schedule}")

    valid, violations = validate_schedule(graph, schedule)
    print(f"Valid: {valid}")
    if not valid:
        for v in violations:
            print(f" VIOLATION: {v}")
    print(f"\n{print_modular_rrt(graph, schedule)}")

    # Full Twill solve
    print("\n--- Full Twill Solve ---")
    result = twill_solve(
        graph,
        max_I=5,
        enable_cost_normalization=False,  # Costs already normalized (1 cycle each)
        enable_memory_constraints=False,  # Simplified example
        enable_warp_constraints=True,
        verbose=True,
    )
    if not result:
        print("FAILED: No solution found!")
        return False

    print(f"\n{result}")

    # Visualization
    viz = visualize_schedule(graph, result.joint_result,
                             output_path="/app/simplified_fa_schedule.png",
                             title="Simplified Flash Attention (Figure 1)")
    print(viz)

    # Code generation
    pseudo = generate_pseudocode(graph, result.joint_result)
    print("\n--- Generated Pseudocode ---")
    print(pseudo)

    # CUDA skeleton
    cuda = generate_cuda_skeleton(graph, result.joint_result)
    print("\n--- Generated CUDA Skeleton ---")
    print(cuda)

    # Previously the Phase 1 validation outcome was printed and then ignored,
    # so a schedule with violations still counted as a pass. Fold it into
    # the verdict after all diagnostic output has been produced.
    if not valid:
        print("FAILED: Phase 1 modulo schedule has violations!")
        return False
    return True
def test_hopper_fmha_forward():
    """Solve the full FMHA forward kernel on Hopper and report warp placement (Section 6.2.1).

    Expected: Twill rediscovers FA3's SWP strategy:
    - S=gemm(Q,K[0]) extracted into prologue
    - Ping-pong scheduling between EXP and TC
    - TMA loads on variable-latency (producer) warp

    NOTE: none of these expectations is asserted. The warp of each key
    instruction is only printed next to the expected placement for manual
    inspection.

    Returns:
        bool: True when ``twill_solve`` yields a truthy result, else False.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("TEST: FMHA Forward on Hopper (Section 6.2.1)")
    print(rule)

    graph = flash_attention_forward_hopper()
    print(f"Graph: {graph}")
    latency_kinds = [
        (instr.name, 'var_lat' if instr.variable_latency else 'fixed')
        for instr in graph.V
    ]
    print(f"Instructions: {latency_kinds}")
    print(f"Min I: {graph.compute_min_initiation_interval()}")

    solver_options = dict(
        max_I=10,
        enable_cost_normalization=False,
        enable_memory_constraints=False,
        enable_warp_constraints=True,
        verbose=True,
    )
    solution = twill_solve(graph, **solver_options)

    # Guard clause: bail out early when the solver found nothing.
    if not solution:
        print("FAILED: No solution found!")
        return False

    print(f"\n{solution}")
    rendering = visualize_schedule(
        graph,
        solution.joint_result,
        output_path="/app/hopper_fmha_schedule.png",
        title="FMHA Forward on Hopper",
    )
    print(rendering)

    pseudocode = generate_pseudocode(graph, solution.joint_result)
    print("\n--- Generated Pseudocode ---")
    print(pseudocode)

    # Report (but do not assert) where the key instructions were placed.
    assignment = solution.warp_assignment
    print("\n--- Verification ---")

    # TMA loads should be on variable-latency warp (W_vl = 0)
    k_load_warp, v_load_warp = assignment.warp_of("K_load"), assignment.warp_of("V_load")
    print(f"K_load on warp {k_load_warp} (expected: 0/variable-latency)")
    print(f"V_load on warp {v_load_warp} (expected: 0/variable-latency)")

    # TC operations should be on compute warps (not warp 0)
    s_gemm_warp, o_gemm_warp = assignment.warp_of("S"), assignment.warp_of("O")
    print(f"S (GEMM) on warp {s_gemm_warp} (expected: compute warp)")
    print(f"O (GEMM) on warp {o_gemm_warp} (expected: compute warp)")

    # EXP should be on a compute warp
    exp_warp = assignment.warp_of("P")
    print(f"P (EXP) on warp {exp_warp} (expected: compute warp)")
    return True
def test_blackwell_fmha_forward():
    """Solve the FMHA forward kernel on Blackwell and print the artifacts (Section 6.2.2).

    Expected: Twill rediscovers FA4's strategy:
    - Variable-latency ops (TMA) on separate warp
    - TC GEMMs on compute warps
    - TMEM read/write for accumulator management
    - Different pipeline structure than Hopper

    NOTE: these expectations are not checked programmatically. The test
    prints the solver result, visualization, pseudocode and CUDA skeleton.

    Returns:
        bool: True when ``twill_solve`` yields a truthy result, else False.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("TEST: FMHA Forward on Blackwell (Section 6.2.2)")
    print(rule)

    graph = flash_attention_forward_blackwell()
    print(f"Graph: {graph}")
    latency_kinds = [
        (instr.name, 'var_lat' if instr.variable_latency else 'fixed')
        for instr in graph.V
    ]
    print(f"Instructions: {latency_kinds}")
    print(f"Min I: {graph.compute_min_initiation_interval()}")

    solver_options = dict(
        max_I=15,
        enable_cost_normalization=False,
        enable_memory_constraints=False,
        enable_warp_constraints=True,
        verbose=True,
    )
    solution = twill_solve(graph, **solver_options)

    # Guard clause: nothing further to print without a solution.
    if not solution:
        print("FAILED: No solution found!")
        return False

    print(f"\n{solution}")
    rendering = visualize_schedule(
        graph,
        solution.joint_result,
        output_path="/app/blackwell_fmha_schedule.png",
        title="FMHA Forward on Blackwell",
    )
    print(rendering)

    pseudocode = generate_pseudocode(graph, solution.joint_result)
    print("\n--- Generated Pseudocode ---")
    print(pseudocode)

    cuda_skeleton = generate_cuda_skeleton(graph, solution.joint_result)
    print("\n--- Generated CUDA Skeleton ---")
    print(cuda_skeleton)
    return True
def test_simple_gemm():
    """Solve the simple GEMM pipeline (load-compute overlap) and print its schedule.

    Unlike the FMHA tests, no ``output_path`` is passed to
    ``visualize_schedule`` here; only its return value is printed.

    Returns:
        bool: True when ``twill_solve`` yields a truthy result, else False.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("TEST: Simple GEMM Pipeline")
    print(rule)

    graph = simple_gemm_pipeline()
    print(f"Graph: {graph}")

    solver_options = dict(
        max_I=6,
        enable_cost_normalization=False,
        enable_memory_constraints=False,
        enable_warp_constraints=True,
        verbose=True,
    )
    solution = twill_solve(graph, **solver_options)

    # Guard clause: report failure as soon as the solver comes back empty.
    if not solution:
        print("FAILED: No solution found!")
        return False

    print(f"\n{solution}")
    rendering = visualize_schedule(
        graph,
        solution.joint_result,
        title="Simple GEMM Pipeline",
    )
    print(rendering)
    return True
def test_modulo_scheduling_only():
    """Test Phase 1 (modulo scheduling) in isolation.

    Sweeps the target initiation interval I from 1 to 5 on the simplified
    flash-attention graph. For each I it prints either the schedule that was
    found (with its validation outcome) or that the I is infeasible.

    Individual infeasible I values are expected and do not fail the test.

    Returns:
        bool: True if at least one I yielded a schedule and every schedule
        found passed ``validate_schedule``; False otherwise. (Previously
        this returned True unconditionally, so invalid schedules or a fully
        infeasible sweep were reported as a pass.)
    """
    print("\n" + "=" * 70)
    print("TEST: Modulo Scheduling Only (Phase 1)")
    print("=" * 70)

    graph = flash_attention_forward_simplified()

    found_any = False   # at least one I produced a schedule
    all_valid = True    # no schedule produced so far failed validation

    # Try increasing I values
    for target_I in range(1, 6):
        print(f"\n Trying I = {target_I}...")
        sched = optimal_modulo_schedule(graph, target_I=target_I, verbose=False)
        if sched:
            found_any = True
            valid, violations = validate_schedule(graph, sched)
            print(f" Found: M={sched.schedule}, L={sched.length}, copies={sched.num_copies}")
            print(f" Valid: {valid}")
            if not valid:
                all_valid = False
                for v in violations:
                    print(f" {v}")
        else:
            print(" Infeasible")

    return found_any and all_valid
if __name__ == "__main__":
    # Script entry point: run every test in a fixed order, print a PASS/FAIL
    # summary, and exit with status 0 only if all tests passed.
    import traceback

    # Banner. The box-drawing characters were previously mis-encoded
    # (UTF-8 bytes read with the wrong code page); the intended glyphs are
    # restored here.
    print("╔" + "═" * 68 + "╗")
    print("║" + " Twill Test Suite ".center(68) + "║")
    print("║" + " Optimal SWP + WS for Tensor Core GPUs ".center(68) + "║")
    print("║" + " (arXiv:2512.18134) ".center(68) + "║")
    print("╚" + "═" * 68 + "╝")

    # (label, test function) pairs in execution order.
    suite = [
        ("Cost Normalization", test_cost_normalization),
        ("Modulo Scheduling Only", test_modulo_scheduling_only),
        ("Simplified FA (Figure 1)", test_simplified_flash_attention),
        ("Simple GEMM", test_simple_gemm),
        ("Hopper FMHA Forward", test_hopper_fmha_forward),
        ("Blackwell FMHA Forward", test_blackwell_fmha_forward),
    ]

    results = {}
    start = time.time()

    # Run tests. Each one is isolated so that an exception (including a
    # failed ``assert``) is recorded as a failure for that test instead of
    # aborting the whole run before the summary is printed.
    for label, test_fn in suite:
        try:
            results[label] = bool(test_fn())
        except Exception:
            traceback.print_exc()
            results[label] = False

    elapsed = time.time() - start

    # Summary
    print("\n" + "=" * 70)
    print("TEST SUMMARY")
    print("=" * 70)
    for name, passed in results.items():
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f" {status} {name}")
    print(f"\nTotal time: {elapsed:.2f}s")
    print(f"Passed: {sum(results.values())}/{len(results)}")

    sys.exit(0 if all(results.values()) else 1)