AshenNav committed · verified
Commit 6cf8dee · 1 parent: 933c21c

Upload twill/gaus_solver.py with huggingface_hub

Files changed (1):
  1. twill/gaus_solver.py +888 -0
twill/gaus_solver.py ADDED
@@ -0,0 +1,888 @@
"""
GauS: Differentiable Scheduling Optimization via Gaussian Reparameterization

Implementation of the paper by Yaohui Cai et al. (arXiv:2602.20427)

GauS models operator scheduling as a stochastic relaxation using Gaussian
distributions, optimized via gradient descent with an Augmented Lagrangian
Method (ALM). It supports:
- Formulation A: Latency-constrained resource + communication optimization
- Formulation B: Latency-constrained memory footprint optimization
- Formulation C: Modulo scheduling (pipelined) — directly comparable to Twill

Key advantages over ILP/SMT (Twill Phase 1+2):
- O(|V|) parameters vs O(D·|V|) for categorical approaches
- Scales to 10K+ operator graphs where ILP/SMT time out
- Exploits GPU parallelism for gradient computation

This module integrates with Twill's DependenceGraph as an alternative solver.
"""

import math
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
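

# Quick-start sketch (a minimal example, not from the paper):
#
#     graph = GausGraph(num_nodes=3, edges=[(0, 1), (1, 2)])
#     solver = GauSSolver(graph, D=6)
#     result = solver.solve_regular(verbose=False)
#     print(result.named_schedule())  # e.g. {'v0': 0, 'v1': 1, 'v2': 2}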


# ============================================================
# Gaussian CDF helper
# ============================================================

def gaussian_cdf(x: torch.Tensor) -> torch.Tensor:
    """Standard Gaussian CDF: Φ(x) = 0.5 * (1 + erf(x / sqrt(2)))"""
    return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


# ============================================================
# GauS Graph representation (standalone, bridges to Twill)
# ============================================================

@dataclass
class GausGraph:
    """Graph representation for GauS solver.

    Attributes:
        num_nodes: Number of operators |V|
        edges: List of (src, dst) forward dependency edges
        back_edges: List of (src, dst, k) loop-carried back-edges with iteration distance k
        resource_weights: Per-node resource demand w_i, shape [|V|]
        memory_weights: Per-node storage bitwidth b_i, shape [|V|]
        successors: Dict mapping node -> list of successor nodes
        predecessors: Dict mapping node -> list of predecessor nodes
        node_names: Optional names for nodes
    """
    num_nodes: int
    edges: List[Tuple[int, int]]  # (src, dst) forward edges
    back_edges: List[Tuple[int, int, int]] = field(default_factory=list)  # (src, dst, k)
    resource_weights: Optional[np.ndarray] = None  # [|V|]
    memory_weights: Optional[np.ndarray] = None  # [|V|]
    node_names: Optional[List[str]] = None

    def __post_init__(self):
        # Build adjacency
        self.successors: Dict[int, List[int]] = {i: [] for i in range(self.num_nodes)}
        self.predecessors: Dict[int, List[int]] = {i: [] for i in range(self.num_nodes)}
        for (u, v) in self.edges:
            self.successors[u].append(v)
            self.predecessors[v].append(u)
        # Defaults
        if self.resource_weights is None:
            self.resource_weights = np.ones(self.num_nodes)
        if self.memory_weights is None:
            self.memory_weights = np.ones(self.num_nodes)
        if self.node_names is None:
            self.node_names = [f"v{i}" for i in range(self.num_nodes)]

    def topological_sort(self) -> List[int]:
        """Kahn's algorithm for topological sort."""
        in_degree = [0] * self.num_nodes
        for (u, v) in self.edges:
            in_degree[v] += 1
        queue = deque(i for i in range(self.num_nodes) if in_degree[i] == 0)
        order = []
        while queue:
            node = queue.popleft()  # deque gives O(1) pops vs O(N) for list.pop(0)
            order.append(node)
            for succ in self.successors[node]:
                in_degree[succ] -= 1
                if in_degree[succ] == 0:
                    queue.append(succ)
        return order


# ============================================================
# ASAP / ALAP computation
# ============================================================

def compute_asap(graph: GausGraph) -> np.ndarray:
    """Compute As-Soon-As-Possible schedule (longest path from sources)."""
    asap = np.zeros(graph.num_nodes, dtype=np.float64)
    for v in graph.topological_sort():
        for pred in graph.predecessors[v]:
            asap[v] = max(asap[v], asap[pred] + 1)
    return asap


def compute_alap(graph: GausGraph, D: int) -> np.ndarray:
    """Compute As-Late-As-Possible schedule (latest feasible time given depth D)."""
    alap = np.full(graph.num_nodes, D - 1, dtype=np.float64)
    for v in reversed(graph.topological_sort()):
        for succ in graph.successors[v]:
            alap[v] = min(alap[v], alap[succ] - 1)
    return alap
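

# Worked example (a sketch): for the chain v0 -> v1 -> v2 with D = 5,
# compute_asap gives [0, 1, 2] and compute_alap gives [2, 3, 4], so every
# node has scheduling freedom ALAP - ASAP = 2, which seeds σ below.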


# ============================================================
# GauS Solver Result
# ============================================================

@dataclass
class GausResult:
    """Result from GauS solver.

    Attributes:
        schedule: Dict mapping node index -> scheduled time step
        initiation_interval: II (for modulo scheduling, else None)
        objective_value: Final objective value
        num_violations: Number of constraint violations in final schedule
        solve_time_seconds: Wall-clock solve time
        iterations: Number of optimization iterations
        loss_history: List of total loss values per iteration
        node_names: Optional node names for display
    """
    schedule: Dict[int, int]
    initiation_interval: Optional[int] = None
    objective_value: float = 0.0
    num_violations: int = 0
    solve_time_seconds: float = 0.0
    iterations: int = 0
    loss_history: List[float] = field(default_factory=list)
    node_names: Optional[List[str]] = None

    @property
    def is_feasible(self) -> bool:
        return self.num_violations == 0

    def named_schedule(self) -> Dict[str, int]:
        if self.node_names:
            return {self.node_names[i]: t for i, t in self.schedule.items()}
        return {f"v{i}": t for i, t in self.schedule.items()}

    def __repr__(self):
        sched_str = self.named_schedule()
        return (
            f"GausResult(\n"
            f"  schedule={sched_str}\n"
            f"  II={self.initiation_interval}\n"
            f"  objective={self.objective_value:.4f}\n"
            f"  violations={self.num_violations}\n"
            f"  feasible={self.is_feasible}\n"
            f"  solve_time={self.solve_time_seconds:.2f}s\n"
            f"  iterations={self.iterations}\n"
            f")"
        )


# ============================================================
# Core GauS Solver
# ============================================================

class GauSSolver:
    """Differentiable scheduling solver using Gaussian reparameterization.

    Implements Algorithm 1 from the paper with all three formulations.

    Usage:
        solver = GauSSolver(graph, D=10)
        result = solver.solve_regular()       # Formulation A
        result = solver.solve_modulo(II=3)    # Formulation C (like Twill)
    """

    def __init__(
        self,
        graph: GausGraph,
        D: int,
        kappa: float = 1.0 / 6.0,
        rho: float = 1e-4,
        tau: float = 1e-2,
        lr: float = 1e-2,
        device: str = "cpu",
    ):
        """
        Args:
            graph: The scheduling graph
            D: Maximum schedule depth (latency bound)
            kappa: Std init scale factor (σ = κ · (ALAP - ASAP))
            rho: ALM penalty coefficient
            tau: LogSumExp temperature
            lr: Adam learning rate
            device: torch device
        """
        self.graph = graph
        self.D = D
        self.kappa = kappa
        self.rho = rho
        self.tau = tau
        self.lr = lr
        self.device = device
        self.N = graph.num_nodes

        # Precompute ASAP/ALAP
        self.s_asap = compute_asap(graph)
        self.s_alap = compute_alap(graph, D)

        # Precompute edge tensors for vectorized loss computation
        if graph.edges:
            self.edge_src = torch.tensor([e[0] for e in graph.edges], dtype=torch.long, device=device)
            self.edge_dst = torch.tensor([e[1] for e in graph.edges], dtype=torch.long, device=device)
        else:
            self.edge_src = torch.tensor([], dtype=torch.long, device=device)
            self.edge_dst = torch.tensor([], dtype=torch.long, device=device)

        # Resource and memory weights
        self.w = torch.tensor(graph.resource_weights, dtype=torch.float32, device=device)
        self.b = torch.tensor(graph.memory_weights, dtype=torch.float32, device=device)

    def _init_params(self, mu_init: Optional[np.ndarray] = None) -> Tuple[nn.Parameter, nn.Parameter]:
        """Initialize μ and σ parameters (Section 3.1)."""
        if mu_init is not None:
            mu = torch.tensor(mu_init, dtype=torch.float32, device=self.device)
        else:
            # μ₀ = (ASAP + ALAP) / 2
            mu = torch.tensor(
                (self.s_asap + self.s_alap) / 2.0,
                dtype=torch.float32, device=self.device,
            )

        # σ = κ · (ALAP - ASAP), floored at 0.5 so nodes with zero scheduling
        # freedom still get a nonzero exploration width
        freedom = self.s_alap - self.s_asap
        sigma_init = self.kappa * np.maximum(freedom, 0.5)
        sigma = torch.tensor(sigma_init, dtype=torch.float32, device=self.device)

        return nn.Parameter(mu), nn.Parameter(sigma)

    def _compute_P(self, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """Compute P_i^d for all (i, d). Equation 7b.

        P_i^d = Φ((d+0.5-μ_i)/σ_i) - Φ((d-0.5-μ_i)/σ_i)

        Returns: shape [N, D]
        """
        sigma_safe = sigma.abs() + 1e-6  # [N]
        d = torch.arange(self.D, dtype=torch.float32, device=self.device)  # [D]

        # Broadcasting: mu [N,1], sigma [N,1], d [1,D]
        mu_exp = mu.unsqueeze(1)           # [N, 1]
        sig_exp = sigma_safe.unsqueeze(1)  # [N, 1]
        d_exp = d.unsqueeze(0)             # [1, D]

        upper = (d_exp + 0.5 - mu_exp) / sig_exp  # [N, D]
        lower = (d_exp - 0.5 - mu_exp) / sig_exp  # [N, D]

        # Boundary handling (footnote 2 in paper): the first slot integrates
        # from -inf and the last slot to +inf, so probability mass outside
        # [0, D-1] is folded into the boundary slots and each row sums to 1.
        lower[:, 0] = float("-inf")
        upper[:, -1] = float("inf")

        P = gaussian_cdf(upper) - gaussian_cdf(lower)  # [N, D]

        # Clamp to avoid numerical issues
        P = P.clamp(min=1e-10, max=1.0)

        return P
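
    # Sanity check (a sketch): with the boundary handling above, each row of
    # P is a proper distribution over the D slots:
    #
    #     solver = GauSSolver(GausGraph(2, [(0, 1)]), D=4)
    #     mu, sigma = solver._init_params()
    #     P = solver._compute_P(mu, sigma)
    #     assert torch.allclose(P.sum(dim=1), torch.ones(2))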

    # ============================================================
    # Loss functions (Section 3.2 + Appendix B)
    # ============================================================

    def _loss_dependency(self, P: torch.Tensor) -> torch.Tensor:
        """Expected dependency violations (Equation 8).

        V_dep = Σ_{(i,j)∈E} Σ_{d_i=0}^{D-1} Σ_{d_j=0}^{d_i} P_i^{d_i} · P_j^{d_j}

        The inner sum runs up to d_j = d_i: a tie (consumer scheduled in the
        producer's slot) counts as a violation, matching the unit-latency
        convention used by _count_violations and the legalizers.

        Vectorized: for each edge (i,j), V = Σ_d P_i^d · CDF_j(d).
        """
        if len(self.edge_src) == 0:
            return torch.tensor(0.0, device=self.device)

        P_src = P[self.edge_src]  # [|E|, D]
        P_dst = P[self.edge_dst]  # [|E|, D]

        # CDF_j(d) = P(X_j ≤ d), via cumulative sum over slots
        cdf_dst = torch.cumsum(P_dst, dim=1)  # [|E|, D]

        # V_dep per edge = Σ_{d_i} P_i^{d_i} · CDF_j(d_i)
        V_dep = torch.sum(P_src * cdf_dst, dim=1)  # [|E|]

        return V_dep.sum()

    def _loss_resource(self, P: torch.Tensor, R_cap: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Expected resource usage + violations (Equations 18-20).

        Res(d) = Σ_i w_i · P_i^d
        L_res  = τ · log(Σ_d exp(Res(d)/τ))   (LogSumExp smooth-max)
        V_res  = Σ_d ReLU(Res(d) - R)         (violations)

        Returns: (L_res, V_res)
        """
        # Res(d), shape [D]
        res_d = torch.matmul(self.w, P)  # [D]

        # LogSumExp smooth-max
        L_res = self.tau * torch.logsumexp(res_d / self.tau, dim=0)

        # Violations
        if R_cap is not None:
            V_res = torch.sum(torch.relu(res_d - R_cap))
        else:
            V_res = torch.tensor(0.0, device=self.device)

        return L_res, V_res
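
    # Note: τ · logsumexp(x/τ) is a smooth upper bound on max(x) that tightens
    # as τ → 0; e.g. for x = [1.0, 2.0], τ = 1.0 gives ≈2.31 while τ = 0.01
    # gives ≈2.0000, at the cost of sharper (harder to optimize) gradients.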

    def _loss_communication(self, P: torch.Tensor) -> torch.Tensor:
        """Expected communication overhead (Equation 17).

        L_com = Σ_{(i,j)∈E} Σ_{d_i} Σ_{d_j≥d_i} P_i^{d_i} · P_j^{d_j} · (d_j - d_i)

        Approximated here by the difference of expectations,
        L_com ≈ Σ_{(i,j)∈E} (E[d_j] - E[d_i]),
        which matches the full double sum exactly once the dependency
        constraint holds (d_j ≥ d_i almost surely) and avoids an O(D²)
        term per edge.
        """
        if len(self.edge_src) == 0:
            return torch.tensor(0.0, device=self.device)

        P_src = P[self.edge_src]  # [|E|, D]
        P_dst = P[self.edge_dst]  # [|E|, D]

        d = torch.arange(self.D, dtype=torch.float32, device=self.device)  # [D]

        # E[d_i] = Σ_d d · P_i^d, and likewise for E[d_j]
        E_src = torch.sum(P_src * d.unsqueeze(0), dim=1)  # [|E|]
        E_dst = torch.sum(P_dst * d.unsqueeze(0), dim=1)  # [|E|]

        L_com = torch.sum(E_dst - E_src)
        return L_com

    def _loss_modulo_resource(
        self, P: torch.Tensor, II: int, R_cap: float
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Modulo resource usage + violations (Equations 21-22, Appendix B.3).

        P_mod[i,t] = Σ_k P_i^{t + k·II}   (wrap probabilities into II slots)
        MRes(t)    = Σ_i w_i · P_mod[i,t]
        V_mres     = Σ_t ReLU(MRes(t) - R_cap)
        """
        # Wrap probabilities into modulo II slots: pad D up to a multiple of
        # II, then fold the wrapped copies onto the II residue classes.
        num_wraps = (self.D + II - 1) // II
        pad = num_wraps * II - self.D
        P_pad = torch.nn.functional.pad(P, (0, pad))       # [N, num_wraps·II]
        P_mod = P_pad.view(self.N, num_wraps, II).sum(dim=1)  # [N, II]

        # MRes(t) = Σ_i w_i · P_mod[i,t], shape [II]
        mres_t = torch.matmul(self.w, P_mod)

        # LogSumExp objective
        L_mres = self.tau * torch.logsumexp(mres_t / self.tau, dim=0)

        # Violations
        V_mres = torch.sum(torch.relu(mres_t - R_cap))

        return L_mres, V_mres
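
    # Example: with D = 6 and II = 2, slots {0, 2, 4} fold into residue t = 0
    # and slots {1, 3, 5} into t = 1, so MRes bounds the steady-state usage of
    # each modulo reservation slot.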

    def _loss_recurrence(self, P: torch.Tensor, II: int) -> torch.Tensor:
        """Expected recurrence violations for loop-carried deps (Equation 24).

        For a back-edge (v_i, v_j, k), the consumer v_j of iteration n+k must
        start after the iteration-n producer v_i finishes (unit latency):
        s_j + k·II ≥ s_i + 1, the same convention as _legalize_modulo.
        The expected violation mass is therefore

        V_rec = Σ_{(v_i,v_j,k)∈E_B} Σ_{d_i} Σ_{d_j ≤ d_i - k·II} P_i^{d_i} · P_j^{d_j}
        """
        if not self.graph.back_edges:
            return torch.tensor(0.0, device=self.device)

        V_rec = torch.tensor(0.0, device=self.device)

        for (vi, vj, k) in self.graph.back_edges:
            # Self-loops: constraint is II ≥ Lat(v), always satisfied if II was
            # chosen correctly. Skip to avoid confusing the optimizer.
            if vi == vj:
                continue

            Pi = P[vi]  # [D]
            Pj = P[vj]  # [D]

            # CDF_j(d) = Σ_{d_j ≤ d} P_j^{d_j}
            cdf_j = torch.cumsum(Pj, dim=0)

            # Violation mass pairs Pi[d_i] with CDF_j(d_i - k·II); vectorized
            # as one dot product over the d_i ≥ k·II tail.
            shift = k * II
            if shift < self.D:
                V_rec = V_rec + torch.sum(Pi[shift:] * cdf_j[: self.D - shift])

        return V_rec

    def _loss_memory(self, P: torch.Tensor) -> torch.Tensor:
        """Expected peak memory footprint (Equations 9-11).

        Active(v_i, d) = P(X_i ≤ d) · P(max_{j∈succ(i)} X_j > d)
        Mem(d) = Σ_i b_i · Active(v_i, d)
        L_mem = τ · log(Σ_d exp(Mem(d)/τ))
        """
        # CDF_i(d) = Σ_{d'≤d} P_i^{d'}
        cdf = torch.cumsum(P, dim=1)  # [N, D]

        # For each node, compute P(all successors finished by d)
        # = Π_{j∈succ(i)} CDF_j(d)  (the X_j are independent by construction)
        all_succ_done = torch.ones(self.N, self.D, device=self.device)
        for i in range(self.N):
            for j in self.graph.successors[i]:
                all_succ_done[i] *= cdf[j]

        # Active(i, d) = CDF_i(d) · (1 - Π_{j∈succ(i)} CDF_j(d))
        # (started and at least one successor hasn't finished)
        active = cdf * (1.0 - all_succ_done)  # [N, D]

        # Mem(d) = Σ_i b_i · Active(i, d)
        mem_d = torch.matmul(self.b, active)  # [D]

        L_mem = self.tau * torch.logsumexp(mem_d / self.tau, dim=0)
        return L_mem

    # ============================================================
    # Legalization (Appendix A)
    # ============================================================

    def _legalize_regular(self, s: np.ndarray) -> np.ndarray:
        """Algorithm 2: Regular schedule legalization via topological pass."""
        s_new = np.clip(s, self.s_asap, self.s_alap).astype(int)

        for v in self.graph.topological_sort():
            preds = self.graph.predecessors[v]
            if preds:
                t_req = max(s_new[p] for p in preds) + 1
                s_new[v] = max(s_new[v], t_req)

        return s_new
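
    # Worked example (a sketch): for the chain v0 -> v1 -> v2 with D = 5 and
    # rounded μ = [2, 1, 1], clipping to [ASAP, ALAP] gives [2, 1, 2] and the
    # topological pass then repairs the order to [2, 3, 4].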

    def _legalize_modulo(self, s: np.ndarray, II: int) -> np.ndarray:
        """Algorithm 3: Modulo schedule legalization via fixed-point iteration."""
        s_new = s.copy().astype(int)
        topo = self.graph.topological_sort()

        for _ in range(self.N):
            changed = False
            for v in topo:
                # Forward dependency requirement
                preds = self.graph.predecessors[v]
                t_min = max((s_new[u] for u in preds), default=-1) + 1

                # Back-edge requirement: s_j + k·II ≥ s_i + Lat(v_i), so for
                # the current node v as consumer (v_j) with unit latency:
                #     s_v ≥ s_producer - k·II + 1
                # Skip self-loops (automatically satisfied by modulo structure)
                t_back = 0
                for (vi, vj, k) in self.graph.back_edges:
                    if vj == v and vi != v:  # v is the consumer, skip self-loops
                        t_back = max(t_back, s_new[vi] - k * II + 1)

                t_req = max(t_min, t_back)
                if t_req > s_new[v]:
                    s_new[v] = t_req
                    changed = True

            if not changed:
                break

        return s_new

    def _count_violations(self, s: np.ndarray, II: Optional[int] = None) -> int:
        """Count constraint violations in a discrete schedule."""
        count = 0
        # Forward dependency violations (unit latency: consumer strictly later)
        for (u, v) in self.graph.edges:
            if s[v] <= s[u]:
                count += 1
        # Recurrence violations: a back-edge (v_i, v_j, k) requires
        # s_j + k·II ≥ s_i + 1 (same convention as _legalize_modulo)
        if II is not None:
            for (vi, vj, k) in self.graph.back_edges:
                if s[vj] + k * II <= s[vi]:
                    count += 1
        return count

    # ============================================================
    # Main solve methods
    # ============================================================

    def solve_regular(
        self,
        max_iters: int = 2000,
        legalize_every: int = 200,
        alpha_com: float = 0.1,
        R_cap: Optional[float] = None,
        verbose: bool = True,
    ) -> GausResult:
        """Solve Formulation A: latency-constrained resource + communication optimization.

        Args:
            max_iters: Maximum optimization iterations
            legalize_every: Legalize and warm-restart every N iterations
            alpha_com: Weight for communication objective
            R_cap: Resource capacity limit (None = no hard cap)
            verbose: Print progress
        """
        return self._solve(
            formulation="A",
            max_iters=max_iters,
            legalize_every=legalize_every,
            alpha_com=alpha_com,
            R_cap=R_cap,
            II=None,
            verbose=verbose,
        )

    def solve_modulo(
        self,
        II: int,
        R_cap: float = 1.0,
        max_iters: int = 3000,
        legalize_every: int = 300,
        verbose: bool = True,
    ) -> GausResult:
        """Solve Formulation C: modulo scheduling (pipelined).

        This is the formulation directly comparable to Twill's domain.

        Args:
            II: Initiation interval (target pipeline rate)
            R_cap: Per-slot resource capacity in modulo reservation table
            max_iters: Maximum optimization iterations
            legalize_every: Legalize and warm-restart every N iterations
            verbose: Print progress
        """
        return self._solve(
            formulation="C",
            max_iters=max_iters,
            legalize_every=legalize_every,
            R_cap=R_cap,
            II=II,
            verbose=verbose,
        )
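
    # Usage sketch (hypothetical 4-op loop with one loop-carried dependence):
    #
    #     g = GausGraph(4, edges=[(0, 1), (1, 2), (2, 3)],
    #                   back_edges=[(3, 0, 1)])
    #     result = GauSSolver(g, D=12).solve_modulo(II=4, verbose=False)
    #     print(result.named_schedule(), result.is_feasible)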

    def _solve(
        self,
        formulation: str,
        max_iters: int,
        legalize_every: int,
        alpha_com: float = 0.1,
        R_cap: Optional[float] = None,
        II: Optional[int] = None,
        verbose: bool = True,
    ) -> GausResult:
        """Core optimization loop implementing Algorithm 1."""
        start_time = time.time()

        if verbose:
            print(f"GauS Solver — Formulation {formulation}")
            print(f"  |V|={self.N}, |E|={len(self.graph.edges)}, D={self.D}")
            if II is not None:
                print(f"  II={II}, R_cap={R_cap}")
            print(f"  ASAP: {self.s_asap}")
            print(f"  ALAP: {self.s_alap}")

        # Initialize parameters
        mu, sigma = self._init_params()
        optimizer = optim.Adam([mu, sigma], lr=self.lr)

        # Lagrange multipliers (ALM)
        lambda_dep = torch.tensor(1e-6, device=self.device)
        lambda_res = torch.tensor(1e-6, device=self.device)
        lambda_mres = torch.tensor(1e-6, device=self.device)
        lambda_rec = torch.tensor(1e-6, device=self.device)

        loss_history = []
        best_schedule = None
        best_objective = float("inf")
        best_violations = float("inf")

        for it in range(max_iters):
            optimizer.zero_grad()

            # Compute P_i^d
            P = self._compute_P(mu, sigma)

            # Primary objectives
            if formulation == "A":
                L_res, V_res = self._loss_resource(P, R_cap)
                L_com = self._loss_communication(P)
                L_primary = L_res + alpha_com * L_com
            elif formulation == "B":
                L_primary = self._loss_memory(P)
                V_res = torch.tensor(0.0, device=self.device)
            elif formulation == "C":
                L_mres, V_mres = self._loss_modulo_resource(P, II, R_cap)
                # Compactness term: encourage short schedules
                L_com = self._loss_communication(P)
                L_primary = L_mres + alpha_com * L_com
                V_res = torch.tensor(0.0, device=self.device)
            else:
                raise ValueError(f"Unknown formulation: {formulation}")

            # Constraint violations
            V_dep = self._loss_dependency(P)

            if formulation == "C":
                V_rec = self._loss_recurrence(P, II)
            else:
                V_rec = torch.tensor(0.0, device=self.device)
                V_mres = torch.tensor(0.0, device=self.device)

            # Augmented Lagrangian
            L_total = L_primary
            L_total = L_total + lambda_dep * V_dep + (self.rho / 2) * V_dep ** 2
            L_total = L_total + lambda_res * V_res + (self.rho / 2) * V_res ** 2
            if formulation == "C":
                L_total = L_total + lambda_mres * V_mres + (self.rho / 2) * V_mres ** 2
                L_total = L_total + lambda_rec * V_rec + (self.rho / 2) * V_rec ** 2

            # Backward + optimize
            L_total.backward()
            optimizer.step()

            # Update Lagrange multipliers: λ ← λ + ρ · violation
            with torch.no_grad():
                lambda_dep = lambda_dep + self.rho * V_dep.detach()
                lambda_res = lambda_res + self.rho * V_res.detach()
                if formulation == "C":
                    lambda_mres = lambda_mres + self.rho * V_mres.detach()
                    lambda_rec = lambda_rec + self.rho * V_rec.detach()

            loss_history.append(L_total.item())

            # Periodic legalization + warm restart
            if (it + 1) % legalize_every == 0 or it == max_iters - 1:
                with torch.no_grad():
                    s_rounded = torch.round(mu).cpu().numpy().astype(int)

                    if formulation == "C" and II is not None:
                        s_legal = self._legalize_modulo(s_rounded, II)
                    else:
                        s_legal = self._legalize_regular(s_rounded)

                    violations = self._count_violations(s_legal, II)

                    # Track best
                    obj_val = L_primary.item()
                    if violations < best_violations or (
                        violations == best_violations and obj_val < best_objective
                    ):
                        best_violations = violations
                        best_objective = obj_val
                        best_schedule = s_legal.copy()

                    if verbose and (it + 1) % legalize_every == 0:
                        sigma_mean = sigma.abs().mean().item()
                        print(f"  iter {it+1:5d}: L={L_total.item():.4f}, "
                              f"V_dep={V_dep.item():.4f}, "
                              f"σ_mean={sigma_mean:.4f}, "
                              f"violations={violations}")

                    # Warm restart: re-initialize μ from the legalized schedule
                    if violations > 0:
                        mu.data = torch.tensor(
                            s_legal, dtype=torch.float32, device=self.device
                        )

        # Final extraction
        with torch.no_grad():
            s_final = torch.round(mu).cpu().numpy().astype(int)
            if formulation == "C" and II is not None:
                s_final = self._legalize_modulo(s_final, II)
            else:
                s_final = self._legalize_regular(s_final)

        final_violations = self._count_violations(s_final, II)

        # Fall back to the best intermediate schedule if the final one is worse
        if best_schedule is not None and final_violations > best_violations:
            s_final = best_schedule
            final_violations = best_violations

        solve_time = time.time() - start_time

        schedule = {i: int(s_final[i]) for i in range(self.N)}

        if verbose:
            print(f"\n  DONE in {solve_time:.2f}s, {max_iters} iterations")
            print(f"  Final schedule: {self._format_schedule(schedule)}")
            print(f"  Violations: {final_violations}")
            if II is not None:
                print(f"  II: {II}")

        return GausResult(
            schedule=schedule,
            initiation_interval=II,
            objective_value=best_objective,
            num_violations=final_violations,
            solve_time_seconds=solve_time,
            iterations=max_iters,
            loss_history=loss_history,
            node_names=self.graph.node_names,
        )

    def _format_schedule(self, schedule: Dict[int, int]) -> str:
        names = self.graph.node_names
        return ", ".join(
            f"{names[i]}@{t}" for i, t in sorted(schedule.items(), key=lambda x: x[1])
        )


# ============================================================
# Integration with Twill's DependenceGraph
# ============================================================

def twill_graph_to_gaus(
    twill_graph,  # twill.graph.DependenceGraph
) -> Tuple[GausGraph, Dict[str, int]]:
    """Convert a Twill DependenceGraph to GauS format.

    Args:
        twill_graph: Twill DependenceGraph object

    Returns:
        (GausGraph, name_to_index mapping)
    """
    V = twill_graph.V
    E = twill_graph.E

    # Build name -> index mapping
    name_to_idx = {v.name: i for i, v in enumerate(V)}

    # Separate forward edges (δ=0) from back-edges (δ>0)
    forward_edges = []
    back_edges = []

    for e in E:
        src_idx = name_to_idx[e.src]
        dst_idx = name_to_idx[e.dst]
        if e.iteration_delay == 0:
            forward_edges.append((src_idx, dst_idx))
        else:
            # Back-edge: (src, dst, iteration_delay)
            back_edges.append((src_idx, dst_idx, e.iteration_delay))

    # Resource weights: sum of RRT usage across all FUs for each instruction
    resource_weights = np.array([v.rrt.sum() for v in V], dtype=np.float64)

    # Memory weights: total memory footprint per instruction
    memory_weights = np.array([
        sum(v.memory_footprint.values()) if v.memory_footprint else 0
        for v in V
    ], dtype=np.float64)

    node_names = [v.name for v in V]

    graph = GausGraph(
        num_nodes=len(V),
        edges=forward_edges,
        back_edges=back_edges,
        resource_weights=resource_weights,
        memory_weights=memory_weights,
        node_names=node_names,
    )

    return graph, name_to_idx


def gaus_solve_twill_graph(
    twill_graph,  # twill.graph.DependenceGraph
    target_II: Optional[int] = None,
    D: Optional[int] = None,
    max_iters: int = 3000,
    verbose: bool = True,
) -> GausResult:
    """Convenience: solve a Twill DependenceGraph using GauS.

    If target_II is given, uses modulo scheduling (Formulation C).
    Otherwise, uses regular scheduling (Formulation A).

    Args:
        twill_graph: Twill DependenceGraph
        target_II: Initiation interval for modulo scheduling
        D: Max depth (auto-computed from total latency if None)
        max_iters: Max optimization iterations
        verbose: Print progress

    Returns:
        GausResult with named schedule
    """
    gaus_graph, name_to_idx = twill_graph_to_gaus(twill_graph)

    if D is None:
        # Conservative depth bound: total latency plus one slot per node
        D = sum(v.cycles for v in twill_graph.V) + len(twill_graph.V)

    solver = GauSSolver(gaus_graph, D=D)

    if target_II is not None:
        R_cap = 1.0  # Default: 1 resource per modulo slot (matches Twill's capacity=1)
        result = solver.solve_modulo(
            II=target_II,
            R_cap=R_cap,
            max_iters=max_iters,
            verbose=verbose,
        )
    else:
        result = solver.solve_regular(
            max_iters=max_iters,
            verbose=verbose,
        )

    return result
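

# Usage sketch (assuming a Twill DependenceGraph instance `dg`):
#
#     result = gaus_solve_twill_graph(dg, target_II=3, verbose=False)
#     print(result.named_schedule())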


# ============================================================
# Synthetic graph generation for scalability testing
# ============================================================

def generate_random_dag(
    num_nodes: int,
    edge_density: float = 0.3,
    max_weight: int = 3,
    num_back_edges: int = 0,
    seed: int = 42,
) -> GausGraph:
    """Generate a random DAG for benchmarking.

    Args:
        num_nodes: Number of operators
        edge_density: Probability of edge between valid pairs
        max_weight: Max resource weight per node
        num_back_edges: Number of loop-carried back-edges to add
        seed: Random seed

    Returns:
        GausGraph
    """
    rng = np.random.RandomState(seed)

    edges = []
    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):
            if rng.random() < edge_density:
                edges.append((i, j))

    # Ensure connectivity: fall back to a simple chain if no edges were drawn
    if not edges:
        for i in range(num_nodes - 1):
            edges.append((i, i + 1))

    resource_weights = rng.randint(1, max_weight + 1, size=num_nodes).astype(np.float64)
    memory_weights = rng.randint(1, max_weight + 1, size=num_nodes).astype(np.float64)

    back_edges = []
    if num_back_edges > 0:
        # Add back-edges from later nodes to earlier nodes
        for _ in range(num_back_edges):
            src = rng.randint(num_nodes // 2, num_nodes)
            dst = rng.randint(0, num_nodes // 2)
            k = rng.randint(1, 3)
            back_edges.append((src, dst, k))

    return GausGraph(
        num_nodes=num_nodes,
        edges=edges,
        back_edges=back_edges,
        resource_weights=resource_weights,
        memory_weights=memory_weights,
    )
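

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not from the paper): build a small random
    # DAG with one loop-carried back-edge and run both formulations briefly.
    g = generate_random_dag(num_nodes=12, edge_density=0.2, num_back_edges=1, seed=0)
    solver = GauSSolver(g, D=24)

    reg = solver.solve_regular(max_iters=400, legalize_every=100, verbose=False)
    print("Regular    :", reg.named_schedule(), "feasible =", reg.is_feasible)

    mod = solver.solve_modulo(II=4, R_cap=8.0, max_iters=400, legalize_every=100, verbose=False)
    print("Modulo II=4:", mod.named_schedule(), "feasible =", mod.is_feasible)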