#!/usr/bin/env python3
"""
Test suite for GauS: Differentiable Scheduling via Gaussian Reparameterization.

Integrated with Twill's kernel descriptions for direct comparison.

Tests:
1. Basic Gaussian reparameterization (Gaussian CDF and P_i^d computation)
2. ASAP / ALAP bounds on a simple chain
3. Regular (non-modulo) scheduling on a simple diamond DAG
4. Modulo scheduling on a simple loop body
5. Modulo scheduling on Twill's FMHA kernels (compare to Twill ILP+SMT)
6. Scalability: 50-, 200- and 1000-node random DAGs

Run this file directly to execute every test and print a PASS/FAIL summary.
A failing test is reported (with its traceback) without stopping the rest of
the suite, and the process exits with status 1 if any test failed.
"""

import sys
import time
import traceback

import numpy as np

sys.path.insert(0, '/app')

from twill.gaus_solver import (
    GauSSolver, GausGraph, GausResult,
    compute_asap, compute_alap, gaussian_cdf,
    twill_graph_to_gaus, gaus_solve_twill_graph,
    generate_random_dag,
)
from twill.kernels import (
    flash_attention_forward_simplified,
    flash_attention_forward_hopper,
    flash_attention_forward_blackwell,
    simple_gemm_pipeline,
)
from twill.twill_solver import twill_solve


def test_gaussian_cdf():
    """Test basic Gaussian CDF computation and the P_i^d distribution.

    Checks ``gaussian_cdf`` at 0 and at +/-5. Then checks that
    ``GauSSolver._compute_P`` for a single node with mu=2 over D=6 steps
    peaks at step 2 and sums to ~1, and becomes ~one-hot at step 2 when
    sigma is small.

    Returns True on success; raises AssertionError on failure.
    """
    print("\n" + "=" * 70)
    print("TEST: Gaussian CDF + P_i^d Computation")
    print("=" * 70)

    import torch

    # CDF at 0 should be 0.5
    assert abs(gaussian_cdf(torch.tensor(0.0)).item() - 0.5) < 1e-6
    # CDF at large positive should be ~1
    assert abs(gaussian_cdf(torch.tensor(5.0)).item() - 1.0) < 1e-4
    # CDF at large negative should be ~0
    assert abs(gaussian_cdf(torch.tensor(-5.0)).item()) < 1e-4

    # Test P_i^d: single node at μ=2, σ=0.5
    graph = GausGraph(num_nodes=1, edges=[])
    solver = GauSSolver(graph, D=6)
    mu = torch.tensor([2.0])
    sigma = torch.tensor([0.5])
    P = solver._compute_P(mu, sigma)

    print(" μ=2.0, σ=0.5, D=6")
    print(f" P = {P[0].detach().numpy().round(4)}")
    print(f" Sum P = {P[0].sum().item():.6f} (should be ~1.0)")
    print(f" Argmax P = {P[0].argmax().item()} (should be 2)")

    assert P[0].argmax().item() == 2, "Peak should be at μ=2"
    assert abs(P[0].sum().item() - 1.0) < 0.01, "Probabilities should sum to ~1"

    # As σ → 0, P should be a delta at round(μ)
    sigma_small = torch.tensor([0.01])
    P_sharp = solver._compute_P(mu, sigma_small)
    print(f" σ=0.01: P[2]={P_sharp[0, 2].item():.6f} (should be ~1.0)")
    assert P_sharp[0, 2].item() > 0.99

    print("✓ Gaussian CDF test passed")
    return True


def test_asap_alap():
    """Test ASAP/ALAP computation on the chain A -> B -> C with D=5.

    Returns True on success; raises AssertionError on failure.
    """
    print("\n" + "=" * 70)
    print("TEST: ASAP / ALAP Computation")
    print("=" * 70)

    # Chain: 0 -> 1 -> 2
    graph = GausGraph(
        num_nodes=3,
        edges=[(0, 1), (1, 2)],
        node_names=["A", "B", "C"],
    )
    asap = compute_asap(graph)
    alap = compute_alap(graph, D=5)

    print(" Chain A->B->C, D=5")
    print(f" ASAP: {asap}")  # Expected: [0, 1, 2]
    print(f" ALAP: {alap}")  # Expected: [2, 3, 4]

    assert list(asap) == [0, 1, 2], f"ASAP wrong: {asap}"
    assert list(alap) == [2, 3, 4], f"ALAP wrong: {alap}"

    print("✓ ASAP/ALAP test passed")
    return True


def test_regular_scheduling():
    """Test regular (non-modulo) scheduling on a 4-node diamond DAG.

    Asserts every edge is respected (successor strictly after predecessor)
    and that the solver reports the schedule as feasible.

    Returns True on success; raises AssertionError on failure.
    """
    print("\n" + "=" * 70)
    print("TEST: Regular Scheduling (Formulation A)")
    print("=" * 70)

    # Diamond: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3
    graph = GausGraph(
        num_nodes=4,
        edges=[(0, 1), (0, 2), (1, 3), (2, 3)],
        resource_weights=np.array([1, 1, 1, 1], dtype=np.float64),
        node_names=["A", "B", "C", "D"],
    )

    solver = GauSSolver(graph, D=6, lr=0.05)
    result = solver.solve_regular(max_iters=500, legalize_every=100, verbose=True)

    print(f"\n Result: {result}")

    # Verify dependencies
    s = result.schedule
    assert s[1] > s[0], f"B must be after A: {s[1]} > {s[0]}"
    assert s[2] > s[0], f"C must be after A: {s[2]} > {s[0]}"
    assert s[3] > s[1], f"D must be after B: {s[3]} > {s[1]}"
    assert s[3] > s[2], f"D must be after C: {s[3]} > {s[2]}"
    assert result.is_feasible, "Schedule should be feasible"

    print("✓ Regular scheduling test passed")
    return True


def test_modulo_scheduling_simple():
    """Test modulo scheduling (II=2, D=8) on a simple 3-node loop body.

    Asserts the two forward edges are respected. Feasibility is printed but
    not asserted.

    Returns True on success; raises AssertionError on failure.
    """
    print("\n" + "=" * 70)
    print("TEST: Modulo Scheduling (Formulation C) — Simple")
    print("=" * 70)

    # Simple loop: S -> P -> O, with O -> O loop-carried
    graph = GausGraph(
        num_nodes=3,
        edges=[(0, 1), (1, 2)],
        back_edges=[(2, 2, 1)],  # O -> O with δ=1
        resource_weights=np.array([1, 1, 1], dtype=np.float64),
        node_names=["S", "P", "O"],
    )

    D = 8
    II = 2
    solver = GauSSolver(graph, D=D, lr=0.02)
    result = solver.solve_modulo(II=II, R_cap=1.0, max_iters=1000, verbose=True)

    print(f"\n Result: {result}")

    # Verify dependencies (values included so failures are diagnosable)
    s = result.schedule
    assert s[1] > s[0], f"P must be after S: {s[1]} > {s[0]}"
    assert s[2] > s[1], f"O must be after P: {s[2]} > {s[1]}"

    print("✓ Modulo scheduling (simple) test passed")
    return True


def _run_twill(graph, max_I):
    """Run ``twill_solve`` on *graph* with cost normalization, memory and warp constraints disabled.

    Prints the Twill result, or a note when no (truthy) result is returned.
    Returns ``(twill_result, elapsed_seconds)``. ``twill_result`` is exactly
    what ``twill_solve`` returned.
    """
    t0 = time.perf_counter()
    twill_result = twill_solve(
        graph,
        max_I=max_I,
        enable_cost_normalization=False,
        enable_memory_constraints=False,
        enable_warp_constraints=False,
        verbose=False,
    )
    twill_time = time.perf_counter() - t0
    if twill_result:
        print(f" Twill: I={twill_result.I}, schedule={twill_result.schedule}, time={twill_time:.3f}s")
    else:
        # Previously silent; make a missing Twill schedule visible in the log.
        print(f" Twill: no schedule found with max_I={max_I} (time={twill_time:.3f}s)")
    return twill_result, twill_time


def _print_comparison(twill_result, twill_time, gaus_result):
    """Print a side-by-side summary of a Twill result (if any) and a GauS result."""
    print("\n--- Comparison ---")
    if twill_result:
        print(f" Twill: I={twill_result.I}, time={twill_time:.3f}s")
    print(f" GauS: II={gaus_result.initiation_interval}, "
          f"feasible={gaus_result.is_feasible}, time={gaus_result.solve_time_seconds:.3f}s")


def test_twill_comparison_simplified_fa():
    """Compare GauS vs Twill on simplified Flash Attention.

    Reporting only: prints both results and returns True without asserting
    on schedule quality. Exceptions raised by either solver propagate.
    """
    print("\n" + "=" * 70)
    print("TEST: GauS vs Twill — Simplified Flash Attention")
    print("=" * 70)

    graph = flash_attention_forward_simplified()

    # Twill solution
    print("--- Twill (ILP + SMT) ---")
    twill_result, twill_time = _run_twill(graph, max_I=5)

    # GauS solution
    print("\n--- GauS (Differentiable) ---")
    gaus_graph, _ = twill_graph_to_gaus(graph, D=10)
    solver = GauSSolver(gaus_graph, D=10, lr=0.02)
    gaus_result = solver.solve_modulo(
        II=2, R_cap=1.0, max_iters=1500, legalize_every=200, verbose=True,
    )
    print(f"\n GauS: {gaus_result}")

    _print_comparison(twill_result, twill_time, gaus_result)
    return True


def test_twill_comparison_hopper():
    """Compare GauS vs Twill on Hopper FMHA forward.

    Reporting only: prints both results and returns True without asserting
    on schedule quality. Exceptions raised by either solver propagate.
    """
    print("\n" + "=" * 70)
    print("TEST: GauS vs Twill — Hopper FMHA Forward")
    print("=" * 70)

    graph = flash_attention_forward_hopper()

    # Twill
    print("--- Twill ---")
    twill_result, twill_time = _run_twill(graph, max_I=10)

    # GauS
    print("\n--- GauS ---")
    gaus_result = gaus_solve_twill_graph(
        graph, target_II=4, D=20, max_iters=2000, verbose=True,
    )

    _print_comparison(twill_result, twill_time, gaus_result)
    return True


def test_scalability():
    """Test GauS scalability on random DAGs of 50, 200 and 1000 nodes.

    Reporting only: prints timing, feasibility, violation count and schedule
    range for each size. Returns True unless a solver call raises.
    """
    print("\n" + "=" * 70)
    print("TEST: Scalability — Random DAGs")
    print("=" * 70)

    for n_nodes in [50, 200, 1000]:
        print(f"\n--- {n_nodes} nodes ---")
        graph = generate_random_dag(
            num_nodes=n_nodes,
            edge_density=min(0.3, 10.0 / n_nodes),  # Keep sparse for large graphs
            max_weight=2,
            num_back_edges=max(1, n_nodes // 20),
            seed=42,
        )

        D = n_nodes + 10
        II = max(2, n_nodes // 10)
        solver = GauSSolver(graph, D=D, lr=0.01)

        t0 = time.perf_counter()
        result = solver.solve_modulo(
            II=II,
            R_cap=float(n_nodes // 5),
            max_iters=min(1000, n_nodes * 5),
            legalize_every=200,
            verbose=False,
        )
        elapsed = time.perf_counter() - t0

        print(f" |V|={n_nodes}, |E|={len(graph.edges)}, D={D}, II={II}")
        print(f" Time: {elapsed:.2f}s")
        print(f" Feasible: {result.is_feasible}")
        print(f" Violations: {result.num_violations}")
        print(f" Schedule range: [{min(result.schedule.values())}, {max(result.schedule.values())}]")

    print("\n✓ Scalability test passed")
    return True


def main():
    """Run every test, print a PASS/FAIL summary and return the exit status.

    Each test is isolated: an exception (including a failed assert) is
    printed with its traceback and recorded as FAIL, and the remaining tests
    still run. Returns 0 if every test passed, otherwise 1.
    """
    print("╔" + "═" * 68 + "╗")
    print("║" + " GauS Test Suite ".center(68) + "║")
    print("║" + " Differentiable Scheduling via Gaussian Reparameterization ".center(68) + "║")
    print("║" + " (arXiv:2602.20427) ".center(68) + "║")
    print("╚" + "═" * 68 + "╝")

    tests = [
        ("Gaussian CDF", test_gaussian_cdf),
        ("ASAP/ALAP", test_asap_alap),
        ("Regular Scheduling", test_regular_scheduling),
        ("Modulo Scheduling (Simple)", test_modulo_scheduling_simple),
        ("GauS vs Twill: Simplified FA", test_twill_comparison_simplified_fa),
        ("GauS vs Twill: Hopper FMHA", test_twill_comparison_hopper),
        ("Scalability", test_scalability),
    ]

    results = {}
    start = time.perf_counter()
    for name, test_fn in tests:
        try:
            results[name] = bool(test_fn())
        except Exception:
            # Top-level runner boundary: report the failure and keep going so
            # the summary and exit code reflect every test, not just the first.
            traceback.print_exc()
            results[name] = False
    elapsed = time.perf_counter() - start

    print("\n" + "=" * 70)
    print("TEST SUMMARY")
    print("=" * 70)
    for name, passed in results.items():
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f" {status} {name}")

    print(f"\nTotal time: {elapsed:.2f}s")
    print(f"Passed: {sum(results.values())}/{len(results)}")
    return 0 if all(results.values()) else 1


if __name__ == "__main__":
    sys.exit(main())