| |
"""
Test suite for GauS: Differentiable Scheduling via Gaussian Reparameterization.
Integrated with Twill's kernel descriptions for direct comparison.

Tests:
1. Basic Gaussian reparameterization (Gaussian CDF and P_i^d computation)
2. ASAP/ALAP bounds on a simple chain
3. Regular (non-modulo) scheduling on a small DAG
4. Modulo scheduling on a simple loop body and on Twill's Flash Attention
   kernels (simplified and Hopper FMHA), compared against Twill's ILP+SMT solver
5. Scalability: random DAGs with 50, 200, and 1000 nodes
"""
|
|
| import sys |
| import time |
| import numpy as np |
|
|
| sys.path.insert(0, '/app') |
|
|
| from twill.gaus_solver import ( |
| GauSSolver, GausGraph, GausResult, |
| compute_asap, compute_alap, gaussian_cdf, |
| twill_graph_to_gaus, gaus_solve_twill_graph, |
| generate_random_dag, |
| ) |
| from twill.kernels import ( |
| flash_attention_forward_simplified, |
| flash_attention_forward_hopper, |
| flash_attention_forward_blackwell, |
| simple_gemm_pipeline, |
| ) |
| from twill.twill_solver import twill_solve |
|
|
|
|
def test_gaussian_cdf():
    """Test the Gaussian CDF helper and the P_i^d computation.

    Checks ``gaussian_cdf`` at 0 and +/-5. Then, on a one-node graph with a
    horizon of D=6 slots, checks that ``GauSSolver._compute_P`` for mu=2.0,
    sigma=0.5 gives a row that sums to ~1 and peaks at slot 2. Finally checks
    that a near-zero sigma (0.01) puts more than 0.99 of the mass on slot 2.

    Returns:
        True on success; failures are reported by raising AssertionError.
    """
    print("\n" + "=" * 70)
    print("TEST: Gaussian CDF + P_i^d Computation")
    print("=" * 70)

    import torch  # local import: only this test manipulates tensors directly

    # Expected values of a standard normal CDF: Phi(0)=0.5, Phi(5)~1, Phi(-5)~0.
    assert abs(gaussian_cdf(torch.tensor(0.0)).item() - 0.5) < 1e-6
    assert abs(gaussian_cdf(torch.tensor(5.0)).item() - 1.0) < 1e-4
    assert abs(gaussian_cdf(torch.tensor(-5.0)).item()) < 1e-4

    D = 6
    graph = GausGraph(num_nodes=1, edges=[])
    solver = GauSSolver(graph, D=D)

    mu = torch.tensor([2.0])
    sigma = torch.tensor([0.5])
    # P[i, d] is presumably the probability that node i is placed in slot d
    # — TODO confirm against GauSSolver._compute_P.
    P = solver._compute_P(mu, sigma)

    print(f" μ=2.0, σ=0.5, D={D}")
    print(f" P = {P[0].detach().numpy().round(4)}")
    print(f" Sum P = {P[0].sum().item():.6f} (should be ~1.0)")
    print(f" Argmax P = {P[0].argmax().item()} (should be 2)")

    assert P[0].argmax().item() == 2, "Peak should be at μ=2"
    assert abs(P[0].sum().item() - 1.0) < 0.01, "Probabilities should sum to ~1"

    # A very narrow Gaussian should collapse almost all mass onto one slot.
    sigma_small = torch.tensor([0.01])
    P_sharp = solver._compute_P(mu, sigma_small)
    print(f" σ=0.01: P[2]={P_sharp[0, 2].item():.6f} (should be ~1.0)")
    assert P_sharp[0, 2].item() > 0.99, (
        f"Narrow σ should concentrate mass at slot 2, got {P_sharp[0, 2].item():.6f}"
    )

    print("✓ Gaussian CDF test passed")
    return True
|
|
|
|
def test_asap_alap():
    """Test ASAP/ALAP computation on the 3-node chain A -> B -> C.

    The expected values assume each chain edge advances by one slot: ASAP is
    [0, 1, 2] and, for a horizon of D=5 slots, ALAP is [2, 3, 4].

    Returns:
        True on success; failures are reported by raising AssertionError.
    """
    print("\n" + "=" * 70)
    print("TEST: ASAP / ALAP Computation")
    print("=" * 70)

    D = 5
    graph = GausGraph(
        num_nodes=3,
        edges=[(0, 1), (1, 2)],
        node_names=["A", "B", "C"],
    )

    asap = compute_asap(graph)
    alap = compute_alap(graph, D=D)

    print(f" Chain A->B->C, D={D}")
    print(f" ASAP: {asap}")
    print(f" ALAP: {alap}")

    assert list(asap) == [0, 1, 2], f"ASAP wrong: {asap}"
    assert list(alap) == [2, 3, 4], f"ALAP wrong: {alap}"

    print("✓ ASAP/ALAP test passed")
    return True
|
|
|
|
def test_regular_scheduling():
    """Test regular (non-modulo) scheduling on a 4-node diamond DAG.

    The graph is A -> B, A -> C, B -> D, C -> D with unit resource weights.
    It runs Formulation A for 500 iterations, legalizing every 100. The test
    checks that every edge's destination is scheduled strictly after its
    source and that the solver reports the schedule as feasible.

    Returns:
        True on success; failures are reported by raising AssertionError.
    """
    print("\n" + "=" * 70)
    print("TEST: Regular Scheduling (Formulation A)")
    print("=" * 70)

    names = ["A", "B", "C", "D"]
    edges = [(0, 1), (0, 2), (1, 3), (2, 3)]
    graph = GausGraph(
        num_nodes=len(names),
        edges=edges,
        resource_weights=np.ones(len(names), dtype=np.float64),
        node_names=names,
    )

    solver = GauSSolver(graph, D=6, lr=0.05)
    result = solver.solve_regular(max_iters=500, legalize_every=100, verbose=True)

    print(f"\n Result: {result}")

    s = result.schedule
    # Check every dependence edge of the graph (not a hand-copied subset) so
    # the assertions cannot drift out of sync with the graph definition.
    for src, dst in edges:
        assert s[dst] > s[src], (
            f"{names[dst]} must be after {names[src]}: "
            f"got {names[dst]}={s[dst]}, {names[src]}={s[src]}"
        )
    assert result.is_feasible, "Schedule should be feasible"

    print("✓ Regular scheduling test passed")
    return True
|
|
|
|
def test_modulo_scheduling_simple():
    """Test modulo scheduling (Formulation C) on a simple 3-op loop body.

    The loop body is the chain S -> P -> O plus a back edge (2, 2, 1) on O.
    The initiation interval is taken to be at least the resource-constrained
    lower bound ceil(sum(resource_weights) / R_cap). With three unit-weight
    ops and R_cap=1 that bound is 3. A smaller II such as 2 leaves only two
    modulo slots of capacity 1 for three ops, which is infeasible by
    pigeonhole before the solver even runs.

    Only precedence along the chain is asserted; feasibility is printed as
    part of the result.

    Returns:
        True on success; failures are reported by raising AssertionError.
    """
    print("\n" + "=" * 70)
    print("TEST: Modulo Scheduling (Formulation C) — Simple")
    print("=" * 70)

    names = ["S", "P", "O"]
    edges = [(0, 1), (1, 2)]
    weights = np.ones(len(names), dtype=np.float64)
    graph = GausGraph(
        num_nodes=len(names),
        edges=edges,
        # Presumably (src, dst, iteration distance) — TODO confirm against GausGraph.
        back_edges=[(2, 2, 1)],
        resource_weights=weights,
        node_names=names,
    )

    D = 8
    R_cap = 1.0
    # Resource-constrained minimum II: each modulo slot holds at most R_cap
    # units, so II slots must cover the total weight.
    res_mii = int(np.ceil(weights.sum() / R_cap))
    II = max(2, res_mii)
    print(f" II={II} (resource lower bound {res_mii}), D={D}")

    solver = GauSSolver(graph, D=D, lr=0.02)
    result = solver.solve_modulo(II=II, R_cap=R_cap, max_iters=1000, verbose=True)

    print(f"\n Result: {result}")

    s = result.schedule
    for src, dst in edges:
        assert s[dst] > s[src], (
            f"{names[dst]} must be after {names[src]}: "
            f"got {names[dst]}={s[dst]}, {names[src]}={s[src]}"
        )

    print("✓ Modulo scheduling (simple) test passed")
    return True
|
|
|
|
def test_twill_comparison_simplified_fa():
    """Compare GauS vs Twill on the simplified Flash Attention kernel.

    Twill (labelled ILP + SMT) runs with ``max_I=5`` and with cost
    normalization, memory constraints and warp constraints disabled. GauS
    converts the same graph with a horizon of ``D=10`` slots and runs modulo
    scheduling at a fixed ``II=2``. Both results and their timings are
    printed side by side.

    This is a reporting test: it makes no assertions about either result and
    returns True unless one of the solvers raises.
    """
    print("\n" + "=" * 70)
    print("TEST: GauS vs Twill — Simplified Flash Attention")
    print("=" * 70)

    graph = flash_attention_forward_simplified()
    max_I = 5
    D = 10

    print("--- Twill (ILP + SMT) ---")
    # perf_counter is monotonic, so wall-clock adjustments cannot skew timings.
    t0 = time.perf_counter()
    twill_result = twill_solve(
        graph, max_I=max_I, enable_cost_normalization=False,
        enable_memory_constraints=False, enable_warp_constraints=False,
        verbose=False,
    )
    twill_time = time.perf_counter() - t0
    if twill_result:
        print(f" Twill: I={twill_result.I}, schedule={twill_result.schedule}, time={twill_time:.3f}s")
    else:
        # Report a failed solve explicitly instead of printing nothing.
        print(f" Twill: no result (max_I={max_I}), time={twill_time:.3f}s")

    print("\n--- GauS (Differentiable) ---")
    # The index mapping returned alongside the graph is not needed here.
    gaus_graph, _ = twill_graph_to_gaus(graph, D=D)
    solver = GauSSolver(gaus_graph, D=D, lr=0.02)

    gaus_result = solver.solve_modulo(
        II=2, R_cap=1.0, max_iters=1500, legalize_every=200, verbose=True,
    )

    print(f"\n GauS: {gaus_result}")

    print("\n--- Comparison ---")
    if twill_result:
        print(f" Twill: I={twill_result.I}, time={twill_time:.3f}s")
    print(f" GauS: II={gaus_result.initiation_interval}, "
          f"feasible={gaus_result.is_feasible}, time={gaus_result.solve_time_seconds:.3f}s")

    return True
|
|
|
|
def test_twill_comparison_hopper():
    """Compare GauS vs Twill on the Hopper FMHA forward kernel.

    Twill runs with ``max_I=10`` and with cost normalization, memory
    constraints and warp constraints disabled. GauS runs via
    ``gaus_solve_twill_graph`` with target II=4, a horizon of D=20 slots and
    2000 iterations. Both results and their timings are printed.

    This is a reporting test: it makes no assertions about either result and
    returns True unless one of the solvers raises.
    """
    print("\n" + "=" * 70)
    print("TEST: GauS vs Twill — Hopper FMHA Forward")
    print("=" * 70)

    graph = flash_attention_forward_hopper()
    max_I = 10

    print("--- Twill ---")
    # perf_counter is monotonic, so wall-clock adjustments cannot skew timings.
    t0 = time.perf_counter()
    twill_result = twill_solve(
        graph, max_I=max_I, enable_cost_normalization=False,
        enable_memory_constraints=False, enable_warp_constraints=False,
        verbose=False,
    )
    twill_time = time.perf_counter() - t0
    if twill_result:
        print(f" Twill: I={twill_result.I}, schedule={twill_result.schedule}, time={twill_time:.3f}s")
    else:
        # Report a failed solve explicitly instead of printing nothing.
        print(f" Twill: no result (max_I={max_I}), time={twill_time:.3f}s")

    print("\n--- GauS ---")
    gaus_result = gaus_solve_twill_graph(
        graph, target_II=4, D=20, max_iters=2000, verbose=True,
    )

    print("\n--- Comparison ---")
    if twill_result:
        print(f" Twill: I={twill_result.I}, time={twill_time:.3f}s")
    print(f" GauS: II={gaus_result.initiation_interval}, "
          f"feasible={gaus_result.is_feasible}, time={gaus_result.solve_time_seconds:.3f}s")

    return True
|
|
|
|
def test_scalability():
    """Smoke-test GauS modulo scheduling on random DAGs of increasing size.

    For each size in (50, 200, 1000), the test generates a seeded random DAG
    with about ``n // 20`` back edges. It then runs ``solve_modulo`` with:

    * ``D = n + 10``
    * ``II = max(2, n // 10)``
    * ``R_cap = n // 5``
    * at most ``min(1000, 5n)`` iterations

    Timing, feasibility, violation count and the schedule's slot range are
    printed. Nothing is asserted about feasibility; the test passes as long
    as every solve completes without raising.

    Returns:
        True on success.
    """
    print("\n" + "=" * 70)
    print("TEST: Scalability — Random DAGs")
    print("=" * 70)

    for n_nodes in (50, 200, 1000):
        print(f"\n--- {n_nodes} nodes ---")
        graph = generate_random_dag(
            num_nodes=n_nodes,
            # Cap the expected out-degree around 10 so large graphs stay sparse.
            edge_density=min(0.3, 10.0 / n_nodes),
            max_weight=2,
            num_back_edges=max(1, n_nodes // 20),
            seed=42,
        )

        D = n_nodes + 10
        II = max(2, n_nodes // 10)
        # NOTE(review): II * R_cap is only 5 * 10 = 50 at n=50, which may be
        # below the total resource weight (weights go up to 2). In that case
        # no feasible modulo schedule exists — confirm this is intended.
        R_cap = float(n_nodes // 5)

        solver = GauSSolver(graph, D=D, lr=0.01)
        # perf_counter is monotonic, so wall-clock adjustments cannot skew timings.
        t0 = time.perf_counter()
        result = solver.solve_modulo(
            II=II, R_cap=R_cap,
            max_iters=min(1000, n_nodes * 5),
            legalize_every=200,
            verbose=False,
        )
        elapsed = time.perf_counter() - t0

        slots = list(result.schedule.values())
        print(f" |V|={n_nodes}, |E|={len(graph.edges)}, D={D}, II={II}")
        print(f" Time: {elapsed:.2f}s")
        print(f" Feasible: {result.is_feasible}")
        print(f" Violations: {result.num_violations}")
        print(f" Schedule range: [{min(slots)}, {max(slots)}]")

    print("\n✓ Scalability test passed")
    return True
|
|
|
|
if __name__ == "__main__":
    import traceback  # only the standalone runner needs it

    print("╔" + "═" * 68 + "╗")
    print("║" + " GauS Test Suite ".center(68) + "║")
    print("║" + " Differentiable Scheduling via Gaussian Reparameterization ".center(68) + "║")
    print("║" + " (arXiv:2602.20427) ".center(68) + "║")
    print("╚" + "═" * 68 + "╝")

    # (summary label, test function), in execution order.
    suite = [
        ("Gaussian CDF", test_gaussian_cdf),
        ("ASAP/ALAP", test_asap_alap),
        ("Regular Scheduling", test_regular_scheduling),
        ("Modulo Scheduling (Simple)", test_modulo_scheduling_simple),
        ("GauS vs Twill: Simplified FA", test_twill_comparison_simplified_fa),
        ("GauS vs Twill: Hopper FMHA", test_twill_comparison_hopper),
        ("Scalability", test_scalability),
    ]

    results = {}
    start = time.perf_counter()
    for name, test_fn in suite:
        # The tests report failure by raising (e.g. AssertionError). Catching
        # per test records a FAIL instead of aborting the remaining tests and
        # the summary. KeyboardInterrupt is not an Exception subclass, so
        # Ctrl-C still stops the run.
        try:
            # Anything but an explicit False return counts as a pass, so
            # pytest-style tests that return None also work here.
            results[name] = test_fn() is not False
        except Exception:
            print(f"\n✗ {name} raised an exception:", file=sys.stderr)
            traceback.print_exc()
            results[name] = False
    elapsed = time.perf_counter() - start

    print("\n" + "=" * 70)
    print("TEST SUMMARY")
    print("=" * 70)
    for name, passed in results.items():
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f" {status} {name}")
    print(f"\nTotal time: {elapsed:.2f}s")
    print(f"Passed: {sum(results.values())}/{len(results)}")

    sys.exit(0 if all(results.values()) else 1)
|
|