# twill-swp-ws / test_gaus.py
# AshenNav's picture
# Upload test_gaus.py with huggingface_hub
# 1717684 verified
#!/usr/bin/env python3
"""
Test suite for GauS: Differentiable Scheduling via Gaussian Reparameterization.
Integrated with Twill's kernel descriptions for direct comparison.
Tests:
1. Gaussian CDF and the per-step probabilities P_i^d
2. ASAP / ALAP start-time computation
3. Regular (non-modulo) scheduling on a small DAG
4. Modulo scheduling on a simple loop body, and on Twill's FMHA kernels
   (compared against Twill's ILP+SMT solver)
5. Scalability: 50-, 200-, and 1000-node random DAGs
"""
import sys
import time
import numpy as np
sys.path.insert(0, '/app')
from twill.gaus_solver import (
GauSSolver, GausGraph, GausResult,
compute_asap, compute_alap, gaussian_cdf,
twill_graph_to_gaus, gaus_solve_twill_graph,
generate_random_dag,
)
from twill.kernels import (
flash_attention_forward_simplified,
flash_attention_forward_hopper,
flash_attention_forward_blackwell,
simple_gemm_pipeline,
)
from twill.twill_solver import twill_solve
def test_gaussian_cdf():
    """Check ``gaussian_cdf`` and the per-step probabilities P_i^d.

    Three checks are made:
      * ``gaussian_cdf`` is ~0.5 at 0, ~1 at +5 and ~0 at -5.
      * For a single isolated node with μ=2, σ=0.5 over D=6 steps,
        ``GauSSolver._compute_P`` peaks at step 2 and sums to ~1.
      * With a very small σ (0.01), almost all mass lands on step 2.

    Returns:
        True when every check passes. A failed check raises AssertionError.
    """
    print("\n" + "=" * 70)
    print("TEST: Gaussian CDF + P_i^d Computation")
    print("=" * 70)
    import torch

    # Φ(0) = 0.5 exactly; the ±5 tails should be within 1e-4 of 1 and 0.
    assert abs(gaussian_cdf(torch.tensor(0.0)).item() - 0.5) < 1e-6
    assert abs(gaussian_cdf(torch.tensor(5.0)).item() - 1.0) < 1e-4
    assert abs(gaussian_cdf(torch.tensor(-5.0)).item()) < 1e-4

    # P_i^d for one node centred at μ=2 with σ=0.5.
    D = 6
    graph = GausGraph(num_nodes=1, edges=[])
    solver = GauSSolver(graph, D=D)
    mu = torch.tensor([2.0])
    sigma = torch.tensor([0.5])
    P = solver._compute_P(mu, sigma)
    peak = P[0].argmax().item()
    total = P[0].sum().item()
    # Derive the label from the actual inputs so it cannot drift from them.
    print(f" μ={mu.item()}, σ={sigma.item()}, D={D}")
    print(f" P = {P[0].detach().numpy().round(4)}")
    print(f" Sum P = {total:.6f} (should be ~1.0)")
    print(f" Argmax P = {peak} (should be 2)")
    assert peak == 2, f"Peak should be at μ=2, got {peak}"
    assert abs(total - 1.0) < 0.01, f"Probabilities should sum to ~1, got {total:.6f}"

    # As σ → 0, P should collapse to a delta at round(μ) = 2.
    sigma_small = torch.tensor([0.01])
    P_sharp = solver._compute_P(mu, sigma_small)
    sharp_mass = P_sharp[0, 2].item()
    print(f" σ=0.01: P[2]={sharp_mass:.6f} (should be ~1.0)")
    assert sharp_mass > 0.99, f"Expected P[2] > 0.99 for σ=0.01, got {sharp_mass:.6f}"
    print("✓ Gaussian CDF test passed")
    return True
def test_asap_alap():
    """Check ASAP/ALAP start times on the 3-node chain A -> B -> C.

    The expected ASAP times are [0, 1, 2]. With D=5 steps, the expected
    ALAP times are [2, 3, 4].

    Returns:
        True when both checks pass. A failed check raises AssertionError.
    """
    print("\n" + "=" * 70)
    print("TEST: ASAP / ALAP Computation")
    print("=" * 70)
    D = 5
    expected_asap = [0, 1, 2]
    expected_alap = [2, 3, 4]
    # Chain: 0 -> 1 -> 2
    graph = GausGraph(
        num_nodes=3,
        edges=[(0, 1), (1, 2)],
        node_names=["A", "B", "C"],
    )
    asap = compute_asap(graph)
    alap = compute_alap(graph, D=D)
    print(f" Chain A->B->C, D={D}")
    print(f" ASAP: {asap}")  # Expected: [0, 1, 2]
    print(f" ALAP: {alap}")  # Expected: [2, 3, 4]
    assert list(asap) == expected_asap, f"ASAP wrong: {asap} (expected {expected_asap})"
    assert list(alap) == expected_alap, f"ALAP wrong: {alap} (expected {expected_alap})"
    print("✓ ASAP/ALAP test passed")
    return True
def test_regular_scheduling():
    """Run regular (non-modulo) scheduling on a 4-node diamond DAG.

    Graph: A -> B, A -> C, B -> D, C -> D, every resource weight 1, D=6 steps.
    For every edge (u, v) the result must satisfy schedule[v] > schedule[u],
    and the solver must report the schedule as feasible.

    Returns:
        True when every check passes. A failed check raises AssertionError.
    """
    print("\n" + "=" * 70)
    print("TEST: Regular Scheduling (Formulation A)")
    print("=" * 70)
    names = ["A", "B", "C", "D"]
    # Diamond: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3
    edges = [(0, 1), (0, 2), (1, 3), (2, 3)]
    graph = GausGraph(
        num_nodes=len(names),
        edges=edges,
        resource_weights=np.ones(len(names), dtype=np.float64),
        node_names=names,
    )
    solver = GauSSolver(graph, D=6, lr=0.05)
    result = solver.solve_regular(max_iters=500, legalize_every=100, verbose=True)
    print(f"\n Result: {result}")

    # Check each dependency from the edge list itself, so the checks cannot
    # drift from the graph. Messages report the actual start times, because
    # a failing check means the expected ">" relation does NOT hold.
    s = result.schedule
    for u, v in edges:
        assert s[v] > s[u], (
            f"{names[v]} must be after {names[u]}: "
            f"got {names[v]}={s[v]}, {names[u]}={s[u]}"
        )
    assert result.is_feasible, "Schedule should be feasible"
    print("✓ Regular scheduling test passed")
    return True
def test_modulo_scheduling_simple():
    """Run modulo scheduling on a tiny 3-node loop body.

    Loop body S -> P -> O, plus a loop-carried self edge O -> O given as the
    back-edge tuple (2, 2, 1). The comment in the original calls the last
    element δ; presumably this is the iteration distance (confirm against
    GausGraph). The solver uses D=8, II=2 and R_cap=1.0.

    Only the intra-iteration dependencies are asserted. Feasibility is
    printed as part of the result but is not required to pass.

    Returns:
        True when every check passes. A failed check raises AssertionError.
    """
    print("\n" + "=" * 70)
    print("TEST: Modulo Scheduling (Formulation C) — Simple")
    print("=" * 70)
    names = ["S", "P", "O"]
    # Simple loop: S -> P -> O, with O -> O loop-carried
    edges = [(0, 1), (1, 2)]
    graph = GausGraph(
        num_nodes=len(names),
        edges=edges,
        back_edges=[(2, 2, 1)],  # O -> O with δ=1
        resource_weights=np.ones(len(names), dtype=np.float64),
        node_names=names,
    )
    D = 8
    II = 2
    solver = GauSSolver(graph, D=D, lr=0.02)
    result = solver.solve_modulo(II=II, R_cap=1.0, max_iters=1000, verbose=True)
    print(f"\n Result: {result}")

    # Each forward edge (u, v) requires v to start strictly after u.
    s = result.schedule
    for u, v in edges:
        assert s[v] > s[u], (
            f"{names[v]} must be after {names[u]}: "
            f"got {names[v]}={s[v]}, {names[u]}={s[u]}"
        )
    print("✓ Modulo scheduling (simple) test passed")
    return True
def test_twill_comparison_simplified_fa():
    """Compare GauS against Twill (ILP + SMT) on simplified Flash Attention.

    Twill runs with ``max_I=5`` and with cost normalization, memory
    constraints and warp constraints all disabled. GauS runs on the converted
    graph with a fixed II=2 over D=10 time steps.

    This is a reporting test: it prints both results and returns True. It
    asserts nothing about either solver's output.
    """
    print("\n" + "=" * 70)
    print("TEST: GauS vs Twill — Simplified Flash Attention")
    print("=" * 70)
    graph = flash_attention_forward_simplified()

    # Twill solution
    print("--- Twill (ILP + SMT) ---")
    t0 = time.perf_counter()
    twill_result = twill_solve(
        graph, max_I=5, enable_cost_normalization=False,
        enable_memory_constraints=False, enable_warp_constraints=False,
        verbose=False,
    )
    twill_time = time.perf_counter() - t0
    if twill_result:
        print(f" Twill: I={twill_result.I}, schedule={twill_result.schedule}, time={twill_time:.3f}s")
    else:
        # Say so explicitly rather than printing nothing.
        print(f" Twill: no solution returned, time={twill_time:.3f}s")

    # GauS solution. The graph conversion and the solver must use the same
    # number of time steps, so D is defined once.
    D = 10
    print("\n--- GauS (Differentiable) ---")
    gaus_graph, _name_to_idx = twill_graph_to_gaus(graph, D=D)
    solver = GauSSolver(gaus_graph, D=D, lr=0.02)
    gaus_result = solver.solve_modulo(
        II=2, R_cap=1.0, max_iters=1500, legalize_every=200, verbose=True,
    )
    print(f"\n GauS: {gaus_result}")

    # Compare
    print("\n--- Comparison ---")
    if twill_result:
        print(f" Twill: I={twill_result.I}, time={twill_time:.3f}s")
    else:
        print(f" Twill: no solution returned, time={twill_time:.3f}s")
    print(f" GauS: II={gaus_result.initiation_interval}, "
          f"feasible={gaus_result.is_feasible}, time={gaus_result.solve_time_seconds:.3f}s")
    return True
def test_twill_comparison_hopper():
    """Compare GauS against Twill on the Hopper FMHA forward kernel.

    Twill runs with ``max_I=10`` and with cost normalization, memory
    constraints and warp constraints all disabled. GauS runs through
    ``gaus_solve_twill_graph`` with target II=4, D=20 and up to 2000
    iterations.

    This is a reporting test: it prints both results and returns True. It
    asserts nothing about either solver's output.
    """
    print("\n" + "=" * 70)
    print("TEST: GauS vs Twill — Hopper FMHA Forward")
    print("=" * 70)
    graph = flash_attention_forward_hopper()

    # Twill
    print("--- Twill ---")
    t0 = time.perf_counter()
    twill_result = twill_solve(
        graph, max_I=10, enable_cost_normalization=False,
        enable_memory_constraints=False, enable_warp_constraints=False,
        verbose=False,
    )
    twill_time = time.perf_counter() - t0
    if twill_result:
        print(f" Twill: I={twill_result.I}, schedule={twill_result.schedule}, time={twill_time:.3f}s")
    else:
        # Say so explicitly rather than printing nothing.
        print(f" Twill: no solution returned, time={twill_time:.3f}s")

    # GauS
    print("\n--- GauS ---")
    gaus_result = gaus_solve_twill_graph(
        graph, target_II=4, D=20, max_iters=2000, verbose=True,
    )

    print("\n--- Comparison ---")
    if twill_result:
        print(f" Twill: I={twill_result.I}, time={twill_time:.3f}s")
    else:
        print(f" Twill: no solution returned, time={twill_time:.3f}s")
    print(f" GauS: II={gaus_result.initiation_interval}, "
          f"feasible={gaus_result.is_feasible}, time={gaus_result.solve_time_seconds:.3f}s")
    return True
def test_scalability():
    """Run modulo scheduling on random DAGs of 50, 200 and 1000 nodes.

    For each size n, the graph is generated with seed 42, edge density
    min(0.3, 10/n), ``max_weight=2`` and max(1, n // 20) back edges. The
    solver uses D = n + 10, II = max(2, n // 10), R_cap = n // 5, at most
    min(1000, 5n) iterations, and legalization every 200 iterations.

    For each size it prints wall-clock time, feasibility, violation count
    and the schedule range. Solution quality is not asserted, so the test
    passes as long as every solve completes.
    """
    print("\n" + "=" * 70)
    print("TEST: Scalability — Random DAGs")
    print("=" * 70)
    for n_nodes in (50, 200, 1000):
        print(f"\n--- {n_nodes} nodes ---")
        graph = generate_random_dag(
            num_nodes=n_nodes,
            # Density falls as 10/n (capped at 0.3) so large graphs stay sparse.
            edge_density=min(0.3, 10.0 / n_nodes),
            max_weight=2,
            num_back_edges=max(1, n_nodes // 20),
            seed=42,
        )
        D = n_nodes + 10
        II = max(2, n_nodes // 10)
        solver = GauSSolver(graph, D=D, lr=0.01)

        t0 = time.perf_counter()
        result = solver.solve_modulo(
            II=II, R_cap=float(n_nodes // 5),
            max_iters=min(1000, n_nodes * 5),
            legalize_every=200,
            verbose=False,
        )
        elapsed = time.perf_counter() - t0

        start_times = list(result.schedule.values())
        print(f" |V|={n_nodes}, |E|={len(graph.edges)}, D={D}, II={II}")
        print(f" Time: {elapsed:.2f}s")
        print(f" Feasible: {result.is_feasible}")
        print(f" Violations: {result.num_violations}")
        print(f" Schedule range: [{min(start_times)}, {max(start_times)}]")
    print("\n✓ Scalability test passed")
    return True
if __name__ == "__main__":
    import traceback

    print("╔" + "═" * 68 + "╗")
    print("║" + " GauS Test Suite ".center(68) + "║")
    print("║" + " Differentiable Scheduling via Gaussian Reparameterization ".center(68) + "║")
    print("║" + " (arXiv:2602.20427) ".center(68) + "║")
    print("╚" + "═" * 68 + "╝")

    # (summary label, test function) pairs, run in this order.
    suite = [
        ("Gaussian CDF", test_gaussian_cdf),
        ("ASAP/ALAP", test_asap_alap),
        ("Regular Scheduling", test_regular_scheduling),
        ("Modulo Scheduling (Simple)", test_modulo_scheduling_simple),
        ("GauS vs Twill: Simplified FA", test_twill_comparison_simplified_fa),
        ("GauS vs Twill: Hopper FMHA", test_twill_comparison_hopper),
        ("Scalability", test_scalability),
    ]

    results = {}
    start = time.perf_counter()
    for name, test_fn in suite:
        # The tests signal failure by raising (mostly AssertionError). Without
        # this guard the first failure would abort the run, and the remaining
        # tests, the summary and the failing exit code would never be reached.
        # Record the failure, show the traceback, and continue with the next test.
        try:
            results[name] = bool(test_fn())
        except Exception:
            print(f"\n✗ {name} raised an exception:")
            # Print to stdout so the traceback stays in order with the report.
            traceback.print_exc(file=sys.stdout)
            results[name] = False
    elapsed = time.perf_counter() - start

    print("\n" + "=" * 70)
    print("TEST SUMMARY")
    print("=" * 70)
    for name, passed in results.items():
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f" {status} {name}")
    print(f"\nTotal time: {elapsed:.2f}s")
    print(f"Passed: {sum(results.values())}/{len(results)}")
    sys.exit(0 if all(results.values()) else 1)