"""
GauS: Differentiable Scheduling Optimization via Gaussian Reparameterization

Implementation of the paper by Yaohui Cai et al. (arXiv:2602.20427)

GauS models operator scheduling as a stochastic relaxation using Gaussian
distributions, optimized via gradient descent with an Augmented Lagrangian
Method (ALM). It supports:
  - Formulation A: Latency-constrained resource + communication optimization
  - Formulation B: Latency-constrained memory footprint optimization
  - Formulation C: Modulo scheduling (pipelined) — directly comparable to Twill

Key advantages over ILP/SMT (Twill Phase 1+2):
  - O(|V|) parameters vs O(D·|V|) for categorical approaches
  - Scales to 10K+ operator graphs where ILP/SMT time out
  - Exploits GPU parallelism for gradient computation

This module integrates with Twill's DependenceGraph as an alternative solver.
"""

import math
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Set

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


# ============================================================
# Gaussian CDF helper
# ============================================================

def gaussian_cdf(x: torch.Tensor) -> torch.Tensor:
    """Standard Gaussian CDF: Φ(x) = 0.5 * (1 + erf(x / sqrt(2)))"""
    return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


# ============================================================
# GauS Graph representation (standalone, bridges to Twill)
# ============================================================

@dataclass
class GausGraph:
    """Graph representation for GauS solver.

    Attributes:
        num_nodes: Number of operators |V|
        edges: List of (src, dst) forward dependency edges
        back_edges: List of (src, dst, k) loop-carried back-edges with
            iteration distance k
        resource_weights: Per-node resource demand w_i, shape [|V|]
            (defaults to all ones)
        memory_weights: Per-node storage bitwidth b_i, shape [|V|]
            (defaults to all ones)
        successors: Dict mapping node -> list of successor nodes
            (built from ``edges`` only; back-edges are not included)
        predecessors: Dict mapping node -> list of predecessor nodes
            (built from ``edges`` only; back-edges are not included)
        node_names: Optional names for nodes (defaults to "v0", "v1", ...)
    """
    num_nodes: int
    edges: List[Tuple[int, int]]  # (src, dst) forward edges
    back_edges: List[Tuple[int, int, int]] = field(default_factory=list)  # (src, dst, k)
    resource_weights: Optional[np.ndarray] = None  # [|V|]
    memory_weights: Optional[np.ndarray] = None  # [|V|]
    node_names: Optional[List[str]] = None

    def __post_init__(self):
        # Build adjacency from the forward edges.
        self.successors: Dict[int, List[int]] = {i: [] for i in range(self.num_nodes)}
        self.predecessors: Dict[int, List[int]] = {i: [] for i in range(self.num_nodes)}
        for (u, v) in self.edges:
            self.successors[u].append(v)
            self.predecessors[v].append(u)

        # Defaults
        if self.resource_weights is None:
            self.resource_weights = np.ones(self.num_nodes)
        if self.memory_weights is None:
            self.memory_weights = np.ones(self.num_nodes)
        if self.node_names is None:
            self.node_names = [f"v{i}" for i in range(self.num_nodes)]

    def topological_sort(self) -> List[int]:
        """Return a topological order of the forward edges (Kahn's algorithm).

        Nodes become ready in FIFO order, so ties are broken by node index /
        edge insertion order.

        Raises:
            ValueError: if the forward edges contain a cycle. A partial order
                would silently corrupt ASAP/ALAP and legalization.
        """
        in_degree = [0] * self.num_nodes
        for (_, v) in self.edges:
            in_degree[v] += 1

        # deque gives O(1) pops from the left (list.pop(0) is O(n)).
        ready = deque(i for i in range(self.num_nodes) if in_degree[i] == 0)
        order: List[int] = []
        while ready:
            node = ready.popleft()
            order.append(node)
            for succ in self.successors[node]:
                in_degree[succ] -= 1
                if in_degree[succ] == 0:
                    ready.append(succ)

        if len(order) != self.num_nodes:
            raise ValueError(
                "Forward edges contain a cycle; loop-carried dependencies "
                "must be given as back_edges"
            )
        return order


# ============================================================
# ASAP / ALAP computation
# ============================================================

def compute_asap(graph: GausGraph) -> np.ndarray:
    """Compute As-Soon-As-Possible schedule (longest path from sources).

    Every forward edge contributes a unit latency; sources are at step 0.
    """
    asap = np.zeros(graph.num_nodes, dtype=np.float64)
    for v in graph.topological_sort():
        for pred in graph.predecessors[v]:
            asap[v] = max(asap[v], asap[pred] + 1)
    return asap


def compute_alap(graph: GausGraph, D: int) -> np.ndarray:
    """Compute As-Late-As-Possible schedule (latest feasible time given depth D).

    Sinks are placed at step D - 1 and every forward edge contributes a unit
    latency. If D is smaller than the critical path, entries may end up below
    the corresponding ASAP values.
    """
    alap = np.full(graph.num_nodes, D - 1, dtype=np.float64)
    for v in reversed(graph.topological_sort()):
        for succ in graph.successors[v]:
            alap[v] = min(alap[v], alap[succ] - 1)
    return alap


# ============================================================
# GauS Solver Result
# ============================================================

@dataclass
class GausResult:
    """Result from GauS solver.

    Attributes:
        schedule: Dict mapping node index -> scheduled time step
        initiation_interval: II (for modulo scheduling, else None)
        objective_value: Relaxed primary objective recorded with the best
            legalized schedule seen during optimization
        num_violations: Number of constraint violations in final schedule
        solve_time_seconds: Wall-clock solve time
        iterations: Number of optimization iterations
        loss_history: List of total loss values per iteration
        node_names: Optional node names for display
    """
    schedule: Dict[int, int]
    initiation_interval: Optional[int] = None
    objective_value: float = 0.0
    num_violations: int = 0
    solve_time_seconds: float = 0.0
    iterations: int = 0
    loss_history: List[float] = field(default_factory=list)
    node_names: Optional[List[str]] = None

    @property
    def is_feasible(self) -> bool:
        """True when the final schedule has no counted violations."""
        return self.num_violations == 0

    def named_schedule(self) -> Dict[str, int]:
        """Return the schedule keyed by node name ("v<i>" when unnamed)."""
        if self.node_names:
            return {self.node_names[i]: t for i, t in self.schedule.items()}
        return {f"v{i}": t for i, t in self.schedule.items()}

    def __repr__(self):
        sched_str = self.named_schedule()
        return (
            f"GausResult(\n"
            f" schedule={sched_str}\n"
            f" II={self.initiation_interval}\n"
            f" objective={self.objective_value:.4f}\n"
            f" violations={self.num_violations}\n"
            f" feasible={self.is_feasible}\n"
            f" solve_time={self.solve_time_seconds:.2f}s\n"
            f" iterations={self.iterations}\n"
            f")"
        )


# ============================================================
# Core GauS Solver
# ============================================================

class GauSSolver:
    """Differentiable scheduling solver using Gaussian reparameterization.

    Implements Algorithm 1 from the paper with all three formulations.

    Usage:
        solver = GauSSolver(graph, D=10)
        result = solver.solve_regular()       # Formulation A
        result = solver.solve_memory()        # Formulation B
        result = solver.solve_modulo(II=3)    # Formulation C (like Twill)
    """

    def __init__(
        self,
        graph: GausGraph,
        D: int,
        kappa: float = 1.0 / 6.0,
        rho: float = 1e-4,
        tau: float = 1e-2,
        lr: float = 1e-2,
        device: str = "cpu",
    ):
        """
        Args:
            graph: The scheduling graph
            D: Maximum schedule depth (latency bound)
            kappa: Std init scale factor (σ = κ · (ALAP - ASAP))
            rho: ALM penalty coefficient
            tau: LogSumExp temperature
            lr: Adam learning rate
            device: torch device

        Raises:
            ValueError: if D < 1, or if the graph's forward edges are cyclic.
        """
        if D < 1:
            raise ValueError(f"D must be at least 1, got {D}")

        self.graph = graph
        self.D = D
        self.kappa = kappa
        self.rho = rho
        self.tau = tau
        self.lr = lr
        self.device = device
        self.N = graph.num_nodes

        # Precompute ASAP/ALAP
        self.s_asap = compute_asap(graph)
        self.s_alap = compute_alap(graph, D)

        # Precompute edge index tensors for vectorized loss computation.
        # An empty edge list yields empty long tensors.
        self.edge_src = torch.tensor(
            [e[0] for e in graph.edges], dtype=torch.long, device=device
        )
        self.edge_dst = torch.tensor(
            [e[1] for e in graph.edges], dtype=torch.long, device=device
        )

        # Resource / memory weights
        self.w = torch.tensor(graph.resource_weights, dtype=torch.float32, device=device)
        self.b = torch.tensor(graph.memory_weights, dtype=torch.float32, device=device)

    def _init_params(
        self, mu_init: Optional[np.ndarray] = None
    ) -> Tuple[nn.Parameter, nn.Parameter]:
        """Initialize μ and σ parameters (Section 3.1).

        Args:
            mu_init: Optional initial means; defaults to (ASAP + ALAP) / 2.

        Returns:
            (mu, sigma) as trainable parameters of shape [N].
        """
        if mu_init is not None:
            mu = torch.tensor(mu_init, dtype=torch.float32, device=self.device)
        else:
            # μ₀ = (ASAP + ALAP) / 2
            mu = torch.tensor(
                (self.s_asap + self.s_alap) / 2.0,
                dtype=torch.float32,
                device=self.device,
            )

        # σ = κ · (ALAP - ASAP), with a floor so zero-slack nodes still get
        # a non-degenerate distribution.
        freedom = self.s_alap - self.s_asap
        sigma_init = self.kappa * np.maximum(freedom, 0.5)
        sigma = torch.tensor(sigma_init, dtype=torch.float32, device=self.device)

        return nn.Parameter(mu), nn.Parameter(sigma)

    def _compute_P(self, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """Compute P_i^d for all (i, d). Equation 7b.

        P_i^d = Φ((d+0.5-μ_i)/σ_i) - Φ((d-0.5-μ_i)/σ_i)

        Boundary handling (footnote 2 in paper): the first bin's lower bound
        is -inf and the last bin's upper bound is +inf, so all probability
        mass is assigned to some step and each row sums to 1 (before the
        numerical clamp).

        Returns: shape [N, D]
        """
        sigma_safe = sigma.abs() + 1e-6  # [N]; σ is unconstrained during training

        # Interior bin edges sit at d - 0.5 for d = 1..D-1.
        inner_edges = torch.arange(
            1, self.D, dtype=torch.float32, device=self.device
        ) - 0.5  # [D-1]
        z = (inner_edges.unsqueeze(0) - mu.unsqueeze(1)) / sigma_safe.unsqueeze(1)
        inner_cdf = gaussian_cdf(z)  # [N, D-1]

        # Φ(-inf) = 0 and Φ(+inf) = 1 close the two outermost bins.
        n = mu.shape[0]
        cdf_edges = torch.cat(
            [inner_cdf.new_zeros(n, 1), inner_cdf, inner_cdf.new_ones(n, 1)], dim=1
        )  # [N, D+1]
        P = cdf_edges[:, 1:] - cdf_edges[:, :-1]  # [N, D]

        # Clamp to avoid numerical issues
        return P.clamp(min=1e-10, max=1.0)

    # ============================================================
    # Loss functions (Section 3.2 + Appendix B)
    # ============================================================

    def _loss_dependency(self, P: torch.Tensor) -> torch.Tensor:
        """Expected dependency violations (Equation 8).

        V_dep = Σ_{(i,j)∈E} Σ_{d_i=1}^{D-1} Σ_{d_j=0}^{d_i-1} P_i^{d_i} · P_j^{d_j}

        Vectorized: for each edge (i,j), V = P_i · cumsum(P_j) shifted

        NOTE(review): this penalizes only d_j < d_i, whereas
        `_count_violations` and the legalizers also treat d_j == d_i as a
        violation (unit latency). Confirm which convention is intended.
        """
        num_edges = self.edge_src.numel()
        if num_edges == 0:
            return torch.tensor(0.0, device=self.device)

        P_src = P[self.edge_src]  # [|E|, D]
        P_dst = P[self.edge_dst]  # [|E|, D]

        # CDF_j(d) = cumsum of P_j up to d
        cdf_dst = torch.cumsum(P_dst, dim=1)  # [|E|, D]

        # Shifted: CDF_j(d_i - 1) = probability consumer is before d_i
        cdf_shifted = torch.cat(
            [torch.zeros(num_edges, 1, device=self.device), cdf_dst[:, :-1]],
            dim=1,
        )  # [|E|, D]

        # V_dep per edge = Σ_{d_i} P_i^{d_i} * CDF_j(d_i - 1)
        return torch.sum(P_src * cdf_shifted, dim=1).sum()

    def _loss_resource(
        self, P: torch.Tensor, R_cap: Optional[float] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Expected resource usage + violations (Equations 18-20).

        Res(d) = Σ_i w_i · P_i^d
        L_res = τ · log(Σ_d exp(Res(d)/τ))   (LogSumExp smooth-max)
        V_res = Σ_d ReLU(Res(d) - R)         (violations; 0 when R_cap is None)

        Returns: (L_res, V_res)
        """
        res_d = torch.matmul(self.w, P)  # [D]

        # LogSumExp smooth-max
        L_res = self.tau * torch.logsumexp(res_d / self.tau, dim=0)

        # Violations
        if R_cap is not None:
            V_res = torch.sum(torch.relu(res_d - R_cap))
        else:
            V_res = torch.tensor(0.0, device=self.device)

        return L_res, V_res

    def _loss_communication(self, P: torch.Tensor) -> torch.Tensor:
        """Expected communication overhead (Equation 17), approximated.

        L_com = Σ_{(i,j)∈E} Σ_{d_i} Σ_{d_j≥d_i} P_i^{d_i} · P_j^{d_j} · (d_j - d_i)

        This implementation uses the cheaper approximation
        L_com ≈ Σ_{(i,j)∈E} (E[d_j] - E[d_i]), which does not restrict the
        sum to valid (d_j ≥ d_i) pairs.
        """
        if self.edge_src.numel() == 0:
            return torch.tensor(0.0, device=self.device)

        steps = torch.arange(self.D, dtype=torch.float32, device=self.device)  # [D]
        expected_step = torch.sum(P * steps.unsqueeze(0), dim=1)  # [N], E[d_i]

        return torch.sum(expected_step[self.edge_dst] - expected_step[self.edge_src])

    def _loss_modulo_resource(
        self, P: torch.Tensor, II: int, R_cap: float
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Modulo resource usage + violations (Equations 21-22, Appendix B.3).

        P_mod[i,t] = Σ_k P_i^{t + k·II}   (wrap probabilities into II slots)
        MRes(t) = Σ_i w_i · P_mod[i,t]
        V_mres = Σ_t ReLU(MRes(t) - R_cap)

        Returns: (L_mres, V_mres) where L_mres is the LogSumExp smooth-max of
        MRes over the II slots.
        """
        # Pad D up to a multiple of II, then fold [N, K·II] -> [N, K, II] and
        # sum over K so that slot t collects steps t, t+II, t+2·II, ...
        num_wraps = (self.D + II - 1) // II
        pad = num_wraps * II - self.D
        if pad:
            P = torch.cat([P, P.new_zeros(P.shape[0], pad)], dim=1)
        P_mod = P.reshape(P.shape[0], num_wraps, II).sum(dim=1)  # [N, II]

        # MRes(t) = Σ_i w_i · P_mod[i,t], shape [II]
        mres_t = torch.matmul(self.w, P_mod)

        # LogSumExp objective
        L_mres = self.tau * torch.logsumexp(mres_t / self.tau, dim=0)

        # Violations
        V_mres = torch.sum(torch.relu(mres_t - R_cap))

        return L_mres, V_mres

    def _loss_recurrence(self, P: torch.Tensor, II: int) -> torch.Tensor:
        """Expected recurrence violations for loop-carried deps (Equation 24).

        V_rec = Σ_{(v_i,v_j,k)∈E_B} Σ_{d_i} Σ_{d_j=d_i+k·II+1}^{D-1}
                P_i^{d_i} · P_j^{d_j}

        A violation occurs when s_j > s_i + k·II (consumer too late).
        """
        V_rec = torch.tensor(0.0, device=self.device)

        for (vi, vj, k) in self.graph.back_edges:
            # Self-loops: constraint is II ≥ Lat(v), always satisfied if II was
            # chosen correctly. Skip to avoid confusing the optimizer.
            if vi == vj:
                continue

            # First d_j that violates for d_i = 0. Offsets outside [0, D)
            # contribute nothing.
            offset = k * II + 1
            if not 0 <= offset < self.D:
                continue

            # Reverse cumsum of Pj: tail_j[d] = Σ_{d_j >= d} P_j^{d_j}
            tail_j = torch.flip(torch.cumsum(torch.flip(P[vj], [0]), dim=0), [0])

            # Σ_{d_i} P_i^{d_i} · tail_j[d_i + offset] for d_i + offset < D
            V_rec = V_rec + torch.sum(P[vi][: self.D - offset] * tail_j[offset:])

        return V_rec

    def _loss_memory(self, P: torch.Tensor) -> torch.Tensor:
        """Expected peak memory footprint (Equations 9-11).

        Active(v_i, d) = P(X_i ≤ d) · P(max_{j∈succ(i)} X_j > d)
        Mem(d) = Σ_i b_i · Active(v_i, d)
        L_mem = τ · log(Σ_d exp(Mem(d)/τ))

        Nodes without successors are never counted as active.
        """
        # CDF_i(d) = Σ_{d'≤d} P_i^{d'}
        cdf = torch.cumsum(P, dim=1)  # [N, D]

        # P(all successors scheduled by d) = Π_{j∈succ(i)} CDF_j(d).
        # Built out-of-place so autograd never sees an in-place update.
        rows = []
        for i in range(self.N):
            done = torch.ones(self.D, device=self.device)
            for j in self.graph.successors[i]:
                done = done * cdf[j]
            rows.append(done)
        if rows:
            all_succ_done = torch.stack(rows)  # [N, D]
        else:
            all_succ_done = torch.ones(0, self.D, device=self.device)

        # Active(i, d) = CDF_i(d) · (1 - Π_{j∈succ(i)} CDF_j(d))
        active = cdf * (1.0 - all_succ_done)  # [N, D]

        # Mem(d) = Σ_i b_i · Active(i, d)
        mem_d = torch.matmul(self.b, active)  # [D]

        return self.tau * torch.logsumexp(mem_d / self.tau, dim=0)

    # ============================================================
    # Legalization (Appendix A)
    # ============================================================

    def _legalize_regular(self, s: np.ndarray) -> np.ndarray:
        """Algorithm 2: Regular schedule legalization via topological pass.

        Clips each node into [ASAP, ALAP], then pushes nodes later until every
        forward edge has at least unit latency. The push can move a node past
        its ALAP bound; that is not checked here.
        """
        s_new = np.clip(s, self.s_asap, self.s_alap).astype(int)

        for v in self.graph.topological_sort():
            preds = self.graph.predecessors[v]
            if preds:
                t_req = max(s_new[p] for p in preds) + 1
                s_new[v] = max(s_new[v], t_req)

        return s_new

    def _legalize_modulo(self, s: np.ndarray, II: int) -> np.ndarray:
        """Algorithm 3: Modulo schedule legalization via fixed-point iteration.

        Repeats up to N sweeps in topological order, only ever moving nodes
        later, until no start time changes.

        NOTE(review): for a back-edge (vi, vj, k) this enforces the lower
        bound s_vj ≥ s_vi - k·II + 1, while `_loss_recurrence` and
        `_count_violations` check the upper bound s_vj ≤ s_vi + k·II.
        Confirm the intended direction of the back-edge tuples.
        """
        s_new = s.copy().astype(int)
        topo = self.graph.topological_sort()

        # Group back-edges by consumer once. Self-loops are skipped
        # (automatically satisfied by the modulo structure).
        incoming_back: Dict[int, List[Tuple[int, int]]] = {}
        for (vi, vj, k) in self.graph.back_edges:
            if vi != vj:
                incoming_back.setdefault(vj, []).append((vi, k))

        for _ in range(self.N):
            changed = False
            for v in topo:
                # Forward dependency requirement
                preds = self.graph.predecessors[v]
                t_min = max((s_new[u] for u in preds), default=-1) + 1

                # Back-edge requirement with v as the consumer:
                # s_v ≥ s_producer - k·II + 1
                t_back = 0
                for (vi, k) in incoming_back.get(v, ()):
                    t_back = max(t_back, s_new[vi] - k * II + 1)

                t_req = max(t_min, t_back)
                if t_req > s_new[v]:
                    s_new[v] = t_req
                    changed = True

            if not changed:
                break

        return s_new

    def _count_violations(self, s: np.ndarray, II: Optional[int] = None) -> int:
        """Count constraint violations in a discrete schedule.

        A forward edge (u, v) is violated when s[v] <= s[u]. When II is given,
        a back-edge (vi, vj, k) is violated when s[vi] + k·II < s[vj].
        """
        count = 0

        # Forward dependency violations
        for (u, v) in self.graph.edges:
            if s[v] <= s[u]:
                count += 1

        # Recurrence violations
        if II is not None:
            for (vi, vj, k) in self.graph.back_edges:
                if s[vi] + k * II < s[vj]:
                    count += 1

        return count

    # ============================================================
    # Main solve methods
    # ============================================================

    def solve_regular(
        self,
        max_iters: int = 2000,
        legalize_every: int = 200,
        alpha_com: float = 0.1,
        R_cap: Optional[float] = None,
        verbose: bool = True,
    ) -> GausResult:
        """Solve Formulation A: latency-constrained resource + communication optimization.

        Args:
            max_iters: Maximum optimization iterations
            legalize_every: Legalize and warm-restart every N iterations
            alpha_com: Weight for communication objective
            R_cap: Resource capacity limit (None = no hard cap)
            verbose: Print progress
        """
        return self._solve(
            formulation="A",
            max_iters=max_iters,
            legalize_every=legalize_every,
            alpha_com=alpha_com,
            R_cap=R_cap,
            II=None,
            verbose=verbose,
        )

    def solve_memory(
        self,
        max_iters: int = 2000,
        legalize_every: int = 200,
        verbose: bool = True,
    ) -> GausResult:
        """Solve Formulation B: latency-constrained memory footprint optimization.

        Args:
            max_iters: Maximum optimization iterations
            legalize_every: Legalize and warm-restart every N iterations
            verbose: Print progress
        """
        return self._solve(
            formulation="B",
            max_iters=max_iters,
            legalize_every=legalize_every,
            II=None,
            verbose=verbose,
        )

    def solve_modulo(
        self,
        II: int,
        R_cap: float = 1.0,
        max_iters: int = 3000,
        legalize_every: int = 300,
        verbose: bool = True,
    ) -> GausResult:
        """Solve Formulation C: modulo scheduling (pipelined).

        This is the formulation directly comparable to Twill's domain.

        Args:
            II: Initiation interval (target pipeline rate)
            R_cap: Per-slot resource capacity in modulo reservation table
            max_iters: Maximum optimization iterations
            legalize_every: Legalize and warm-restart every N iterations
            verbose: Print progress
        """
        return self._solve(
            formulation="C",
            max_iters=max_iters,
            legalize_every=legalize_every,
            R_cap=R_cap,
            II=II,
            verbose=verbose,
        )

    def _solve(
        self,
        formulation: str,
        max_iters: int,
        legalize_every: int,
        alpha_com: float = 0.1,
        R_cap: Optional[float] = None,
        II: Optional[int] = None,
        verbose: bool = True,
    ) -> GausResult:
        """Core optimization loop implementing Algorithm 1.

        Args:
            formulation: "A", "B" or "C"
            max_iters: Number of gradient steps
            legalize_every: Legalize and warm-restart every N iterations
            alpha_com: Weight of the communication/compactness term (A and C)
            R_cap: Resource capacity (optional for A, required for C)
            II: Initiation interval (required for C)
            verbose: Print progress

        Raises:
            ValueError: on an unknown formulation, legalize_every < 1, or a
                missing/non-positive II or missing R_cap for Formulation C.
        """
        if formulation not in ("A", "B", "C"):
            raise ValueError(f"Unknown formulation: {formulation}")
        if legalize_every < 1:
            raise ValueError(f"legalize_every must be at least 1, got {legalize_every}")
        modulo = formulation == "C"
        if modulo:
            if II is None or II < 1:
                raise ValueError(f"Formulation C requires a positive II, got {II}")
            if R_cap is None:
                raise ValueError("Formulation C requires R_cap")

        start_time = time.time()

        if verbose:
            print(f"GauS Solver — Formulation {formulation}")
            print(f" |V|={self.N}, |E|={len(self.graph.edges)}, D={self.D}")
            if II:
                print(f" II={II}, R_cap={R_cap}")
            print(f" ASAP: {self.s_asap}")
            print(f" ALAP: {self.s_alap}")

        # Initialize parameters
        mu, sigma = self._init_params()
        optimizer = optim.Adam([mu, sigma], lr=self.lr)

        # Lagrange multipliers (ALM)
        lambda_dep = torch.tensor(1e-6, device=self.device)
        lambda_res = torch.tensor(1e-6, device=self.device)
        lambda_mres = torch.tensor(1e-6, device=self.device)
        lambda_rec = torch.tensor(1e-6, device=self.device)

        zero = torch.tensor(0.0, device=self.device)
        half_rho = self.rho / 2

        loss_history = []
        best_schedule = None
        best_objective = float('inf')
        best_violations = float('inf')

        for it in range(max_iters):
            optimizer.zero_grad()

            # Compute P_i^d
            P = self._compute_P(mu, sigma)

            # Primary objective and formulation-specific violations.
            V_res = zero
            V_mres = zero
            V_rec = zero
            if formulation == "A":
                L_res, V_res = self._loss_resource(P, R_cap)
                L_primary = L_res + alpha_com * self._loss_communication(P)
            elif formulation == "B":
                L_primary = self._loss_memory(P)
            else:
                L_mres, V_mres = self._loss_modulo_resource(P, II, R_cap)
                # Compactness term: encourage short schedules
                L_primary = L_mres + alpha_com * self._loss_communication(P)
                V_rec = self._loss_recurrence(P, II)

            V_dep = self._loss_dependency(P)

            # Augmented Lagrangian
            L_total = L_primary
            L_total = L_total + lambda_dep * V_dep + half_rho * V_dep ** 2
            L_total = L_total + lambda_res * V_res + half_rho * V_res ** 2
            if modulo:
                L_total = L_total + lambda_mres * V_mres + half_rho * V_mres ** 2
                L_total = L_total + lambda_rec * V_rec + half_rho * V_rec ** 2

            # Backward + optimize
            L_total.backward()
            optimizer.step()

            # Update Lagrange multipliers
            with torch.no_grad():
                lambda_dep = lambda_dep + self.rho * V_dep.detach()
                lambda_res = lambda_res + self.rho * V_res.detach()
                if modulo:
                    lambda_mres = lambda_mres + self.rho * V_mres.detach()
                    lambda_rec = lambda_rec + self.rho * V_rec.detach()

            loss_history.append(L_total.item())

            # Periodic legalization + warm restart
            at_checkpoint = (it + 1) % legalize_every == 0
            if at_checkpoint or it == max_iters - 1:
                with torch.no_grad():
                    s_rounded = torch.round(mu).cpu().numpy().astype(int)

                if modulo:
                    s_legal = self._legalize_modulo(s_rounded, II)
                else:
                    s_legal = self._legalize_regular(s_rounded)

                violations = self._count_violations(s_legal, II)

                # Track best: fewest violations first, then lowest objective.
                obj_val = L_primary.item()
                if violations < best_violations or (
                    violations == best_violations and obj_val < best_objective
                ):
                    best_violations = violations
                    best_objective = obj_val
                    best_schedule = s_legal.copy()

                if verbose and at_checkpoint:
                    sigma_mean = sigma.abs().mean().item()
                    print(f" iter {it+1:5d}: L={L_total.item():.4f}, "
                          f"V_dep={V_dep.item():.4f}, "
                          f"σ_mean={sigma_mean:.4f}, "
                          f"violations={violations}")

                # Warm restart: re-initialize μ from legalized schedule.
                # Only the values change; σ and optimizer state are kept.
                if violations > 0:
                    with torch.no_grad():
                        mu.copy_(torch.as_tensor(
                            s_legal, dtype=torch.float32, device=self.device
                        ))

        # Final extraction
        with torch.no_grad():
            s_final = torch.round(mu).cpu().numpy().astype(int)

        if modulo:
            s_final = self._legalize_modulo(s_final, II)
        else:
            s_final = self._legalize_regular(s_final)

        final_violations = self._count_violations(s_final, II)

        # Use best if final has more violations
        if best_schedule is not None and final_violations > best_violations:
            s_final = best_schedule
            final_violations = best_violations

        solve_time = time.time() - start_time
        schedule = {i: int(s_final[i]) for i in range(self.N)}

        if verbose:
            print(f"\n DONE in {solve_time:.2f}s, {max_iters} iterations")
            print(f" Final schedule: {self._format_schedule(schedule)}")
            print(f" Violations: {final_violations}")
            if II:
                print(f" II: {II}")

        return GausResult(
            schedule=schedule,
            initiation_interval=II,
            objective_value=best_objective,
            num_violations=final_violations,
            solve_time_seconds=solve_time,
            iterations=max_iters,
            loss_history=loss_history,
            node_names=self.graph.node_names,
        )

    def _format_schedule(self, schedule: Dict[int, int]) -> str:
        """Render the schedule as "name@step" entries sorted by step."""
        names = self.graph.node_names
        return ", ".join(
            f"{names[i]}@{t}"
            for i, t in sorted(schedule.items(), key=lambda x: x[1])
        )


# ============================================================
# Integration with Twill's DependenceGraph
# ============================================================

def _default_twill_depth(twill_graph) -> int:
    """Default latency bound for a Twill graph: Σ cycles + |V|."""
    return sum(v.cycles for v in twill_graph.V) + len(twill_graph.V)


def twill_graph_to_gaus(
    twill_graph,  # twill.graph.DependenceGraph
    D: Optional[int] = None,
) -> Tuple[GausGraph, Dict[str, int]]:
    """Convert a Twill DependenceGraph to GauS format.

    Reads `V` (nodes with `name`, `rrt`, `memory_footprint`) and `E` (edges
    with `src`, `dst`, `iteration_delay`) from `twill_graph`; edge endpoints
    are looked up by node name.

    Args:
        twill_graph: Twill DependenceGraph object
        D: Max depth. Accepted for backward compatibility; the conversion
            itself does not depend on it.

    Returns:
        (GausGraph, name_to_index mapping)
    """
    V = twill_graph.V
    E = twill_graph.E

    # Build name -> index mapping
    name_to_idx = {v.name: i for i, v in enumerate(V)}

    # Separate forward edges (δ=0) from back-edges (δ>0)
    forward_edges = []
    back_edges = []
    for e in E:
        src_idx = name_to_idx[e.src]
        dst_idx = name_to_idx[e.dst]
        if e.iteration_delay == 0:
            forward_edges.append((src_idx, dst_idx))
        else:
            # Back-edge: (src, dst, iteration_delay)
            back_edges.append((src_idx, dst_idx, e.iteration_delay))

    # Resource weights: sum of RRT usage across all FUs for each instruction
    resource_weights = np.array([v.rrt.sum() for v in V], dtype=np.float64)

    # Memory weights: total declared footprint per node (0 when absent/empty)
    memory_weights = np.array(
        [sum(v.memory_footprint.values()) if v.memory_footprint else 0 for v in V],
        dtype=np.float64,
    )

    graph = GausGraph(
        num_nodes=len(V),
        edges=forward_edges,
        back_edges=back_edges,
        resource_weights=resource_weights,
        memory_weights=memory_weights,
        node_names=[v.name for v in V],
    )

    return graph, name_to_idx


def gaus_solve_twill_graph(
    twill_graph,  # twill.graph.DependenceGraph
    target_II: Optional[int] = None,
    D: Optional[int] = None,
    max_iters: int = 3000,
    verbose: bool = True,
) -> GausResult:
    """Convenience: solve a Twill DependenceGraph using GauS.

    If target_II is given, uses modulo scheduling (Formulation C).
    Otherwise, uses regular scheduling (Formulation A).

    Args:
        twill_graph: Twill DependenceGraph
        target_II: Initiation interval for modulo scheduling
        D: Max depth (Σ cycles + |V| if None)
        max_iters: Max optimization iterations
        verbose: Print progress

    Returns:
        GausResult with named schedule
    """
    # `is None` (not truthiness) so an explicit D is never silently replaced.
    if D is None:
        D = _default_twill_depth(twill_graph)

    gaus_graph, _ = twill_graph_to_gaus(twill_graph, D)
    solver = GauSSolver(gaus_graph, D=D)

    if target_II is not None:
        # Default: 1 resource per modulo slot (matches Twill's capacity=1)
        return solver.solve_modulo(
            II=target_II,
            R_cap=1.0,
            max_iters=max_iters,
            verbose=verbose,
        )

    return solver.solve_regular(
        max_iters=max_iters,
        verbose=verbose,
    )


# ============================================================
# Synthetic graph generation for scalability testing
# ============================================================

def generate_random_dag(
    num_nodes: int,
    edge_density: float = 0.3,
    max_weight: int = 3,
    num_back_edges: int = 0,
    seed: int = 42,
) -> GausGraph:
    """Generate a random DAG for benchmarking.

    Forward edges always go from a lower to a higher node index, so the
    result is acyclic. If sampling produces no edge at all, a simple chain
    0 -> 1 -> ... is used instead.

    Args:
        num_nodes: Number of operators
        edge_density: Probability of edge between valid pairs
        max_weight: Max resource weight per node
        num_back_edges: Number of loop-carried back-edges to add
        seed: Random seed

    Returns:
        GausGraph

    Raises:
        ValueError: if back-edges are requested for fewer than 2 nodes.
    """
    if num_back_edges > 0 and num_nodes < 2:
        raise ValueError("num_back_edges > 0 requires at least 2 nodes")

    rng = np.random.RandomState(seed)

    # One draw per (i, j) pair with i < j; draw order is part of the seed
    # contract, so keep it a plain nested loop.
    edges = []
    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):
            if rng.random() < edge_density:
                edges.append((i, j))

    # Fall back to a chain when nothing was sampled.
    if not edges:
        edges = [(i, i + 1) for i in range(num_nodes - 1)]

    resource_weights = rng.randint(1, max_weight + 1, size=num_nodes).astype(np.float64)
    memory_weights = rng.randint(1, max_weight + 1, size=num_nodes).astype(np.float64)

    # Back-edges go from the later half of the nodes to the earlier half,
    # with iteration distance 1 or 2.
    back_edges = []
    for _ in range(num_back_edges):
        src = rng.randint(num_nodes // 2, num_nodes)
        dst = rng.randint(0, num_nodes // 2)
        k = rng.randint(1, 3)
        back_edges.append((src, dst, k))

    return GausGraph(
        num_nodes=num_nodes,
        edges=edges,
        back_edges=back_edges,
        resource_weights=resource_weights,
        memory_weights=memory_weights,
    )