AshenNav committed on
Commit
1717684
·
verified ·
1 Parent(s): 6cf8dee

Upload test_gaus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test_gaus.py +308 -0
test_gaus.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test suite for GauS: Differentiable Scheduling via Gaussian Reparameterization.
4
+ Integrated with Twill's kernel descriptions for direct comparison.
5
+
6
+ Tests:
7
+ 1. Basic Gaussian reparameterization (P_i^d computation)
8
+ 2. Regular scheduling on simple DAGs
9
+ 3. Modulo scheduling on Twill's FMHA kernels (compare to Twill ILP+SMT)
10
+ 4. Scalability: 50-, 200-, and 1000-node random DAGs
11
+ """
12
+
13
+ import sys
14
+ import time
15
+ import numpy as np
16
+
17
+ sys.path.insert(0, '/app')
18
+
19
+ from twill.gaus_solver import (
20
+ GauSSolver, GausGraph, GausResult,
21
+ compute_asap, compute_alap, gaussian_cdf,
22
+ twill_graph_to_gaus, gaus_solve_twill_graph,
23
+ generate_random_dag,
24
+ )
25
+ from twill.kernels import (
26
+ flash_attention_forward_simplified,
27
+ flash_attention_forward_hopper,
28
+ flash_attention_forward_blackwell,
29
+ simple_gemm_pipeline,
30
+ )
31
+ from twill.twill_solver import twill_solve
32
+
33
+
34
def test_gaussian_cdf():
    """Check `gaussian_cdf` reference values and the per-step distribution P.

    The test has two parts:

    1. `gaussian_cdf` must return 0.5 at 0, ~1 at +5 and ~0 at -5.
    2. For a one-node `GausGraph` wrapped in `GauSSolver(graph, D=6)`,
       `solver._compute_P(mu, sigma)` with mu=2 must peak at index 2 and sum
       to ~1 for sigma=0.5, and must place more than 0.99 of its mass on
       index 2 for sigma=0.01.

    Returns:
        True when every assertion holds.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    print("\n" + "=" * 70)
    print("TEST: Gaussian CDF + P_i^d Computation")
    print("=" * 70)

    # Imported here rather than at module level, as in the original script.
    import torch

    # CDF at 0 should be 0.5
    assert abs(gaussian_cdf(torch.tensor(0.0)).item() - 0.5) < 1e-6
    # CDF at large positive should be ~1
    assert abs(gaussian_cdf(torch.tensor(5.0)).item() - 1.0) < 1e-4
    # CDF at large negative should be ~0
    assert abs(gaussian_cdf(torch.tensor(-5.0)).item()) < 1e-4

    # Test P_i^d: single node at μ=2, σ=0.5
    graph = GausGraph(num_nodes=1, edges=[])
    solver = GauSSolver(graph, D=6)

    mu = torch.tensor([2.0])
    sigma = torch.tensor([0.5])
    # NOTE(review): `_compute_P` is a private solver method; P[0] is presumably
    # the distribution of node 0 over the time steps - confirm in gaus_solver.
    P = solver._compute_P(mu, sigma)

    print(" μ=2.0, σ=0.5, D=6")
    print(f" P = {P[0].detach().numpy().round(4)}")
    print(f" Sum P = {P[0].sum().item():.6f} (should be ~1.0)")
    print(f" Argmax P = {P[0].argmax().item()} (should be 2)")

    assert P[0].argmax().item() == 2, "Peak should be at μ=2"
    assert abs(P[0].sum().item() - 1.0) < 0.01, "Probabilities should sum to ~1"

    # As σ → 0, P should be a delta at round(μ)
    sigma_small = torch.tensor([0.01])
    P_sharp = solver._compute_P(mu, sigma_small)
    print(f" σ=0.01: P[2]={P_sharp[0, 2].item():.6f} (should be ~1.0)")
    assert P_sharp[0, 2].item() > 0.99

    print("✓ Gaussian CDF test passed")
    return True
73
+
74
+
75
def test_asap_alap():
    """Check `compute_asap` / `compute_alap` on a three-node chain.

    Builds the chain 0 -> 1 -> 2 (named A, B, C) and asserts that
    `compute_asap(graph)` yields [0, 1, 2] and `compute_alap(graph, D=5)`
    yields [2, 3, 4].

    Returns:
        True when both assertions hold.

    Raises:
        AssertionError: if either result differs from the expected list.
    """
    print("\n" + "=" * 70)
    print("TEST: ASAP / ALAP Computation")
    print("=" * 70)

    # Chain: 0 -> 1 -> 2
    graph = GausGraph(
        num_nodes=3,
        edges=[(0, 1), (1, 2)],
        node_names=["A", "B", "C"],
    )

    asap = compute_asap(graph)
    alap = compute_alap(graph, D=5)

    print(" Chain A->B->C, D=5")
    print(f" ASAP: {asap}")  # Expected: [0, 1, 2]
    print(f" ALAP: {alap}")  # Expected: [2, 3, 4]

    # list(...) lets the comparison work for any iterable result type.
    assert list(asap) == [0, 1, 2], f"ASAP wrong: {asap}"
    assert list(alap) == [2, 3, 4], f"ALAP wrong: {alap}"

    print("✓ ASAP/ALAP test passed")
    return True
100
+
101
+
102
def test_regular_scheduling():
    """Run regular (non-modulo) scheduling on a diamond DAG and check precedence.

    The graph is 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3 (named A, B, C, D) with unit
    resource weights. After `solve_regular`, every edge's sink must be
    scheduled strictly later than its source and `result.is_feasible` must be
    truthy.

    Returns:
        True when every assertion holds.

    Raises:
        AssertionError: if a dependency is violated or the result is reported
            as infeasible.
    """
    print("\n" + "=" * 70)
    print("TEST: Regular Scheduling (Formulation A)")
    print("=" * 70)

    # Diamond: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3
    graph = GausGraph(
        num_nodes=4,
        edges=[(0, 1), (0, 2), (1, 3), (2, 3)],
        resource_weights=np.array([1, 1, 1, 1], dtype=np.float64),
        node_names=["A", "B", "C", "D"],
    )

    solver = GauSSolver(graph, D=6, lr=0.05)
    result = solver.solve_regular(max_iters=500, legalize_every=100, verbose=True)

    print(f"\n Result: {result}")

    # Verify dependencies: each edge's sink must come strictly after its source.
    s = result.schedule
    assert s[1] > s[0], f"B must be after A: {s[1]} > {s[0]}"
    assert s[2] > s[0], f"C must be after A: {s[2]} > {s[0]}"
    assert s[3] > s[1], f"D must be after B: {s[3]} > {s[1]}"
    assert s[3] > s[2], f"D must be after C: {s[3]} > {s[2]}"
    assert result.is_feasible, "Schedule should be feasible"

    print("✓ Regular scheduling test passed")
    return True
131
+
132
+
133
def test_modulo_scheduling_simple():
    """Run modulo scheduling on a three-node loop body and check its forward edges.

    The loop body is the chain 0 -> 1 -> 2 (named S, P, O) plus one back edge
    (2, 2, 1). The solver is run with D=8, II=2 and R_cap=1.0. Only the two
    forward dependencies are asserted; the back edge and `is_feasible` are
    not checked here.

    Returns:
        True when both assertions hold.

    Raises:
        AssertionError: if P is not after S, or O is not after P.
    """
    print("\n" + "=" * 70)
    print("TEST: Modulo Scheduling (Formulation C) — Simple")
    print("=" * 70)

    # Simple loop: S -> P -> O, with O -> O loop-carried
    graph = GausGraph(
        num_nodes=3,
        edges=[(0, 1), (1, 2)],
        back_edges=[(2, 2, 1)],  # O -> O with δ=1
        resource_weights=np.array([1, 1, 1], dtype=np.float64),
        node_names=["S", "P", "O"],
    )

    D = 8
    II = 2
    solver = GauSSolver(graph, D=D, lr=0.02)
    result = solver.solve_modulo(II=II, R_cap=1.0, max_iters=1000, verbose=True)

    print(f"\n Result: {result}")

    # Verify dependencies
    s = result.schedule
    assert s[1] > s[0], "P must be after S"
    assert s[2] > s[1], "O must be after P"

    print("✓ Modulo scheduling (simple) test passed")
    return True
162
+
163
+
164
def test_twill_comparison_simplified_fa():
    """Print a side-by-side run of Twill and GauS on simplified Flash Attention.

    Solves the graph from `flash_attention_forward_simplified()` twice: once
    with `twill_solve` (timed here) and once with `GauSSolver.solve_modulo`
    after converting it via `twill_graph_to_gaus`. The results are only
    printed; nothing is asserted about either schedule.

    Returns:
        True unconditionally, once both solvers have returned.
    """
    print("\n" + "=" * 70)
    print("TEST: GauS vs Twill — Simplified Flash Attention")
    print("=" * 70)

    graph = flash_attention_forward_simplified()

    # Twill solution
    print("--- Twill (ILP + SMT) ---")
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments the way time.time() differences can.
    t0 = time.perf_counter()
    twill_result = twill_solve(
        graph, max_I=5, enable_cost_normalization=False,
        enable_memory_constraints=False, enable_warp_constraints=False,
        verbose=False,
    )
    twill_time = time.perf_counter() - t0
    # NOTE(review): twill_solve presumably returns a falsy value when it finds
    # no schedule - confirm against twill_solver.
    if twill_result:
        print(f" Twill: I={twill_result.I}, schedule={twill_result.schedule}, time={twill_time:.3f}s")

    # GauS solution
    print("\n--- GauS (Differentiable) ---")
    horizon = 10  # same D for the graph conversion and for the solver
    # The name -> index mapping returned by the conversion is not needed here.
    gaus_graph, _name_to_idx = twill_graph_to_gaus(graph, D=horizon)
    solver = GauSSolver(gaus_graph, D=horizon, lr=0.02)

    gaus_result = solver.solve_modulo(
        II=2, R_cap=1.0, max_iters=1500, legalize_every=200, verbose=True,
    )

    print(f"\n GauS: {gaus_result}")

    # Compare
    print("\n--- Comparison ---")
    if twill_result:
        print(f" Twill: I={twill_result.I}, time={twill_time:.3f}s")
    print(f" GauS: II={gaus_result.initiation_interval}, "
          f"feasible={gaus_result.is_feasible}, time={gaus_result.solve_time_seconds:.3f}s")

    return True
203
+
204
+
205
def test_twill_comparison_hopper():
    """Print a side-by-side run of Twill and GauS on the Hopper FMHA forward graph.

    Solves the graph from `flash_attention_forward_hopper()` with
    `twill_solve` (timed here) and with `gaus_solve_twill_graph`
    (target_II=4, D=20). The results are only printed; nothing is asserted
    about either schedule.

    Returns:
        True unconditionally, once both solvers have returned.
    """
    print("\n" + "=" * 70)
    print("TEST: GauS vs Twill — Hopper FMHA Forward")
    print("=" * 70)

    graph = flash_attention_forward_hopper()

    # Twill
    print("--- Twill ---")
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments the way time.time() differences can.
    t0 = time.perf_counter()
    twill_result = twill_solve(
        graph, max_I=10, enable_cost_normalization=False,
        enable_memory_constraints=False, enable_warp_constraints=False,
        verbose=False,
    )
    twill_time = time.perf_counter() - t0
    # NOTE(review): twill_solve presumably returns a falsy value when it finds
    # no schedule - confirm against twill_solver.
    if twill_result:
        print(f" Twill: I={twill_result.I}, schedule={twill_result.schedule}, time={twill_time:.3f}s")

    # GauS
    print("\n--- GauS ---")
    gaus_result = gaus_solve_twill_graph(
        graph, target_II=4, D=20, max_iters=2000, verbose=True,
    )

    print("\n--- Comparison ---")
    if twill_result:
        print(f" Twill: I={twill_result.I}, time={twill_time:.3f}s")
    print(f" GauS: II={gaus_result.initiation_interval}, "
          f"feasible={gaus_result.is_feasible}, time={gaus_result.solve_time_seconds:.3f}s")

    return True
238
+
239
+
240
def test_scalability():
    """Time GauS modulo scheduling on random DAGs of 50, 200 and 1000 nodes.

    For each size a graph is drawn with `generate_random_dag` (seed=42), the
    solver is run with parameters derived from the node count, and the
    elapsed time, feasibility flag, violation count and schedule range are
    printed. Nothing is asserted, so the closing "passed" line only means
    that every run finished without raising.

    Returns:
        True unconditionally, once all sizes have been run.
    """
    print("\n" + "=" * 70)
    print("TEST: Scalability — Random DAGs")
    print("=" * 70)

    for n_nodes in [50, 200, 1000]:
        print(f"\n--- {n_nodes} nodes ---")
        graph = generate_random_dag(
            num_nodes=n_nodes,
            edge_density=min(0.3, 10.0 / n_nodes),  # Keep sparse for large graphs
            max_weight=2,
            num_back_edges=max(1, n_nodes // 20),
            seed=42,
        )

        D = n_nodes + 10
        II = max(2, n_nodes // 10)

        solver = GauSSolver(graph, D=D, lr=0.01)
        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # system clock adjustments the way time.time() differences can.
        t0 = time.perf_counter()
        result = solver.solve_modulo(
            II=II, R_cap=float(n_nodes // 5),
            max_iters=min(1000, n_nodes * 5),
            legalize_every=200,
            verbose=False,
        )
        elapsed = time.perf_counter() - t0

        # Materialize the scheduled steps once for the min/max report below.
        steps = list(result.schedule.values())

        print(f" |V|={n_nodes}, |E|={len(graph.edges)}, D={D}, II={II}")
        print(f" Time: {elapsed:.2f}s")
        print(f" Feasible: {result.is_feasible}")
        print(f" Violations: {result.num_violations}")
        print(f" Schedule range: [{min(steps)}, {max(steps)}]")

    print("\n✓ Scalability test passed")
    return True
277
+
278
+
279
if __name__ == "__main__":
    # Script entry point: run every test in order, print a PASS/FAIL summary,
    # and exit with status 0 only when all of them passed.
    import traceback

    print("╔" + "═" * 68 + "╗")
    print("║" + " GauS Test Suite ".center(68) + "║")
    print("║" + " Differentiable Scheduling via Gaussian Reparameterization ".center(68) + "║")
    print("║" + " (arXiv:2602.20427) ".center(68) + "║")
    print("╚" + "═" * 68 + "╝")

    suite = [
        ("Gaussian CDF", test_gaussian_cdf),
        ("ASAP/ALAP", test_asap_alap),
        ("Regular Scheduling", test_regular_scheduling),
        ("Modulo Scheduling (Simple)", test_modulo_scheduling_simple),
        ("GauS vs Twill: Simplified FA", test_twill_comparison_simplified_fa),
        ("GauS vs Twill: Hopper FMHA", test_twill_comparison_hopper),
        ("Scalability", test_scalability),
    ]

    results = {}
    start = time.perf_counter()

    for name, test_fn in suite:
        # The tests signal failure by raising (usually AssertionError). Without
        # this guard the first failure aborted the script before the summary,
        # so the FAIL branch below could never be reached.
        try:
            results[name] = bool(test_fn())
        except Exception:
            traceback.print_exc()
            results[name] = False

    elapsed = time.perf_counter() - start

    print("\n" + "=" * 70)
    print("TEST SUMMARY")
    print("=" * 70)
    for name, passed in results.items():
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f" {status} {name}")
    print(f"\nTotal time: {elapsed:.2f}s")
    print(f"Passed: {sum(results.values())}/{len(results)}")

    sys.exit(0 if all(results.values()) else 1)