| """ |
| GauS: Differentiable Scheduling Optimization via Gaussian Reparameterization |
| |
| Implementation of the paper by Yaohui Cai et al. (arXiv:2602.20427) |
| |
| GauS models operator scheduling as a stochastic relaxation using Gaussian |
| distributions, optimized via gradient descent with an Augmented Lagrangian |
| Method (ALM). It supports: |
| - Formulation A: Latency-constrained resource + communication optimization |
| - Formulation B: Latency-constrained memory footprint optimization |
| - Formulation C: Modulo scheduling (pipelined) — directly comparable to Twill |
| |
| Key advantages over ILP/SMT (Twill Phase 1+2): |
| - O(|V|) parameters vs O(D·|V|) for categorical approaches |
| - Scales to 10K+ operator graphs where ILP/SMT time out |
| - Exploits GPU parallelism for gradient computation |
| |
| This module integrates with Twill's DependenceGraph as an alternative solver. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import numpy as np |
| import math |
| import time |
| from typing import Dict, List, Optional, Tuple, Set |
| from dataclasses import dataclass, field |
|
|
|
|
| |
| |
| |
|
|
def gaussian_cdf(x: torch.Tensor) -> torch.Tensor:
    """Evaluate the standard normal CDF element-wise.

    Uses the identity Φ(x) = 0.5 * (1 + erf(x / √2)), which stays
    differentiable through ``torch.erf``.
    """
    root_two = math.sqrt(2.0)
    erf_part = torch.erf(x / root_two)
    return 0.5 * (1.0 + erf_part)
|
|
|
|
| |
| |
| |
|
|
@dataclass
class GausGraph:
    """Graph representation for GauS solver.

    Attributes:
        num_nodes: Number of operators |V|
        edges: List of (src, dst) forward dependency edges
        back_edges: List of (src, dst, k) loop-carried back-edges with iteration distance k
        resource_weights: Per-node resource demand w_i, shape [|V|] (defaults to all ones)
        memory_weights: Per-node storage bitwidth b_i, shape [|V|] (defaults to all ones)
        successors: Dict mapping node -> list of successor nodes (forward edges only)
        predecessors: Dict mapping node -> list of predecessor nodes (forward edges only)
        node_names: Optional names for nodes (defaults to "v0", "v1", ...)
    """
    num_nodes: int
    edges: List[Tuple[int, int]]
    back_edges: List[Tuple[int, int, int]] = field(default_factory=list)
    resource_weights: Optional[np.ndarray] = None
    memory_weights: Optional[np.ndarray] = None
    node_names: Optional[List[str]] = None

    def __post_init__(self):
        """Build forward adjacency lists and fill in per-node defaults."""
        # Only forward edges are indexed here; back edges are read directly
        # from ``back_edges`` by the modulo-scheduling code.
        self.successors: Dict[int, List[int]] = {i: [] for i in range(self.num_nodes)}
        self.predecessors: Dict[int, List[int]] = {i: [] for i in range(self.num_nodes)}
        for (u, v) in self.edges:
            self.successors[u].append(v)
            self.predecessors[v].append(u)

        if self.resource_weights is None:
            self.resource_weights = np.ones(self.num_nodes)
        if self.memory_weights is None:
            self.memory_weights = np.ones(self.num_nodes)
        if self.node_names is None:
            self.node_names = [f"v{i}" for i in range(self.num_nodes)]

    def topological_sort(self) -> List[int]:
        """Return a topological order of the forward edges (Kahn's algorithm).

        Source nodes are seeded in index order and ties are broken FIFO.

        Raises:
            ValueError: if the forward edges contain a cycle. A partial order
                would otherwise silently drop nodes from ASAP/ALAP computation
                and legalization.
        """
        from collections import deque

        in_degree = [0] * self.num_nodes
        for (_, v) in self.edges:
            in_degree[v] += 1
        # deque gives O(1) popleft; list.pop(0) made the sort O(|V|^2).
        queue = deque(i for i in range(self.num_nodes) if in_degree[i] == 0)
        order: List[int] = []
        while queue:
            node = queue.popleft()
            order.append(node)
            for succ in self.successors[node]:
                in_degree[succ] -= 1
                if in_degree[succ] == 0:
                    queue.append(succ)
        if len(order) != self.num_nodes:
            raise ValueError(
                f"forward edges contain a cycle: only {len(order)} of "
                f"{self.num_nodes} nodes could be ordered"
            )
        return order
|
|
|
|
| |
| |
| |
|
|
def compute_asap(graph: GausGraph) -> np.ndarray:
    """Return each node's earliest start step under unit-latency forward edges.

    Nodes without predecessors start at 0. Every other node starts one step
    after its latest-starting predecessor (longest path from the sources).
    Nodes are visited in ``graph.topological_sort()`` order.
    """
    earliest = np.zeros(graph.num_nodes, dtype=np.float64)
    for node in graph.topological_sort():
        parents = graph.predecessors[node]
        if not parents:
            continue
        earliest[node] = max(earliest[node], max(earliest[p] for p in parents) + 1)
    return earliest
|
|
|
|
def compute_alap(graph: GausGraph, D: int) -> np.ndarray:
    """Return each node's latest start step that still fits within D steps.

    Sinks may start as late as D - 1. Every other node must start at least
    one step before its earliest-starting successor. Results can fall below
    the ASAP bound (even below 0) when D is shorter than the critical path.
    """
    latest = np.full(graph.num_nodes, D - 1, dtype=np.float64)
    for node in reversed(graph.topological_sort()):
        children = graph.successors[node]
        if not children:
            continue
        latest[node] = min(latest[node], min(latest[c] for c in children) - 1)
    return latest
|
|
|
|
| |
| |
| |
|
|
@dataclass
class GausResult:
    """Outcome of a GauS solve.

    Attributes:
        schedule: Node index -> integer start step
        initiation_interval: II used for modulo scheduling, or None
        objective_value: Best relaxed objective value recorded
        num_violations: Constraint violations left in the returned schedule
        solve_time_seconds: Wall-clock duration of the solve
        iterations: Optimization iterations performed
        loss_history: Total loss after every iteration
        node_names: Optional display names, indexed by node
    """
    schedule: Dict[int, int]
    initiation_interval: Optional[int] = None
    objective_value: float = 0.0
    num_violations: int = 0
    solve_time_seconds: float = 0.0
    iterations: int = 0
    loss_history: List[float] = field(default_factory=list)
    node_names: Optional[List[str]] = None

    @property
    def is_feasible(self) -> bool:
        """True when the schedule has no remaining constraint violations."""
        return not self.num_violations

    def named_schedule(self) -> Dict[str, int]:
        """Map node labels to start steps, using ``node_names`` when available
        and "v<index>" labels otherwise."""
        names = self.node_names
        return {
            (names[idx] if names else f"v{idx}"): step
            for idx, step in self.schedule.items()
        }

    def __repr__(self):
        parts = (
            "GausResult(\n",
            f" schedule={self.named_schedule()}\n",
            f" II={self.initiation_interval}\n",
            f" objective={self.objective_value:.4f}\n",
            f" violations={self.num_violations}\n",
            f" feasible={self.is_feasible}\n",
            f" solve_time={self.solve_time_seconds:.2f}s\n",
            f" iterations={self.iterations}\n",
            ")",
        )
        return "".join(parts)
|
|
|
|
| |
| |
| |
|
|
class GauSSolver:
    """Differentiable scheduling solver using Gaussian reparameterization.

    Implements Algorithm 1 from the paper with all three formulations.

    Each operator i gets a start time X_i ~ N(μ_i, σ_i²), discretised onto
    the integer steps 0..D-1 (``_compute_P``). Relaxed objectives plus
    expected constraint violations are minimised with Adam under an
    Augmented Lagrangian, and the rounded means are periodically legalized.

    Timing convention (unit latency). It is shared by the loss terms, the
    legalizers and ``_count_violations``:
      * forward edge (u, v):            s_v >= s_u + 1
      * back edge (u, v, k) at rate II: s_v + k·II >= s_u + 1

    Usage:
        solver = GauSSolver(graph, D=10)
        result = solver.solve_regular()     # Formulation A
        result = solver.solve_memory()      # Formulation B
        result = solver.solve_modulo(II=3)  # Formulation C (like Twill)
    """

    def __init__(
        self,
        graph: "GausGraph",
        D: int,
        kappa: float = 1.0 / 6.0,
        rho: float = 1e-4,
        tau: float = 1e-2,
        lr: float = 1e-2,
        device: str = "cpu",
    ):
        """
        Args:
            graph: The scheduling graph
            D: Maximum schedule depth (latency bound); steps are 0..D-1
            kappa: Std init scale factor (σ = κ · max(ALAP - ASAP, 0.5))
            rho: ALM penalty coefficient
            tau: LogSumExp temperature
            lr: Adam learning rate
            device: torch device
        """
        self.graph = graph
        self.D = D
        self.kappa = kappa
        self.rho = rho
        self.tau = tau
        self.lr = lr
        self.device = device
        self.N = graph.num_nodes

        # Mobility window of every node under unit latency.
        self.s_asap = compute_asap(graph)
        self.s_alap = compute_alap(graph, D)

        # Edge endpoint indices, shape [|E|]. An empty edge list yields empty
        # long tensors, which the edge-based loss terms check for.
        self.edge_src = torch.tensor([u for u, _ in graph.edges], dtype=torch.long, device=device)
        self.edge_dst = torch.tensor([v for _, v in graph.edges], dtype=torch.long, device=device)

        # Per-node resource (w) and memory (b) weights.
        self.w = torch.tensor(graph.resource_weights, dtype=torch.float32, device=device)
        self.b = torch.tensor(graph.memory_weights, dtype=torch.float32, device=device)

    def _init_params(self, mu_init: Optional[np.ndarray] = None) -> Tuple[nn.Parameter, nn.Parameter]:
        """Initialize μ and σ parameters (Section 3.1).

        By default μ is the midpoint of each node's [ASAP, ALAP] window. σ is
        κ times the window width, floored at 0.5 so that nodes with no
        freedom still get a non-degenerate distribution.

        Args:
            mu_init: Optional explicit initial means, one per node.

        Returns:
            (mu, sigma) as trainable parameters with one entry per node.
        """
        if mu_init is not None:
            mu = torch.tensor(mu_init, dtype=torch.float32, device=self.device)
        else:
            mu = torch.tensor(
                (self.s_asap + self.s_alap) / 2.0,
                dtype=torch.float32, device=self.device,
            )

        freedom = self.s_alap - self.s_asap
        sigma_init = self.kappa * np.maximum(freedom, 0.5)
        sigma = torch.tensor(sigma_init, dtype=torch.float32, device=self.device)

        return nn.Parameter(mu), nn.Parameter(sigma)

    def _compute_P(self, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """Compute P_i^d for all (i, d). Equation 7b.

        P_i^d = Φ((d+0.5-μ_i)/σ_i) - Φ((d-0.5-μ_i)/σ_i)

        Returns: shape [N, D]
        """
        # |σ| + ε keeps the standardisation finite even if Adam drives σ
        # through zero or negative.
        sigma_safe = sigma.abs() + 1e-6
        d = torch.arange(self.D, dtype=torch.float32, device=self.device)

        mu_exp = mu.unsqueeze(1)
        sig_exp = sigma_safe.unsqueeze(1)
        d_exp = d.unsqueeze(0)

        upper = (d_exp + 0.5 - mu_exp) / sig_exp
        lower = (d_exp - 0.5 - mu_exp) / sig_exp

        P = gaussian_cdf(upper) - gaussian_cdf(lower)

        # Floor at 1e-10 so later products/logs never see exact zeros.
        P = P.clamp(min=1e-10, max=1.0)

        return P

    def _loss_dependency(self, P: torch.Tensor) -> torch.Tensor:
        """Expected forward-dependency violations.

        Every edge (i, j) needs X_j >= X_i + 1, so a violation is X_j <= X_i:

            V_dep = Σ_{(i,j)∈E} Σ_{d_i} P_i^{d_i} · P(X_j <= d_i)

        This is the inclusive variant of Equation 8, which sums only
        d_j < d_i. The inclusive bound matches the unit-latency rule that
        ``_legalize_regular`` enforces and ``_count_violations`` counts.
        """
        if self.edge_src.numel() == 0:
            return torch.tensor(0.0, device=self.device)

        # cdf_dst[e, d] = P(X_dst <= d) for edge e.
        cdf_dst = torch.cumsum(P[self.edge_dst], dim=1)
        return torch.sum(P[self.edge_src] * cdf_dst)

    def _loss_resource(self, P: torch.Tensor, R_cap: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Expected resource usage + violations (Equations 18-20).

        Res(d) = Σ_i w_i · P_i^d
        L_res = τ · log(Σ_d exp(Res(d)/τ))   (LogSumExp smooth-max)
        V_res = Σ_d ReLU(Res(d) - R)          (violations; 0 if R_cap is None)

        Returns: (L_res, V_res)
        """
        res_d = torch.matmul(self.w, P)

        L_res = self.tau * torch.logsumexp(res_d / self.tau, dim=0)

        if R_cap is not None:
            V_res = torch.sum(torch.relu(res_d - R_cap))
        else:
            V_res = torch.tensor(0.0, device=self.device)

        return L_res, V_res

    def _loss_communication(self, P: torch.Tensor) -> torch.Tensor:
        """Expected communication overhead (Equation 17), approximated.

        Computed as Σ_{(i,j)∈E} (E[X_j] - E[X_i]), where E[X_i] = Σ_d d · P_i^d
        is taken over the discretised (clamped) distribution. Terms are not
        restricted to X_j >= X_i, so individual edges can contribute
        negative values.
        """
        if self.edge_src.numel() == 0:
            return torch.tensor(0.0, device=self.device)

        d = torch.arange(self.D, dtype=torch.float32, device=self.device)
        expected = P @ d  # one expected start step per node
        return torch.sum(expected[self.edge_dst] - expected[self.edge_src])

    def _loss_modulo_resource(
        self, P: torch.Tensor, II: int, R_cap: float
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Modulo resource usage + violations (Equations 21-22, Appendix B.3).

        P_mod[i,t] = Σ_k P_i^{t + k·II}   (wrap probabilities into II slots)
        MRes(t) = Σ_i w_i · P_mod[i,t]
        L_mres = τ · log(Σ_t exp(MRes(t)/τ))
        V_mres = Σ_t ReLU(MRes(t) - R_cap)
        """
        num_nodes, depth = P.shape
        num_folds = -(-depth // II)  # ceil(D / II)
        pad = num_folds * II - depth
        if pad:
            # Zero-pad the tail so the step axis splits evenly into II-wide folds.
            P = torch.cat([P, P.new_zeros(num_nodes, pad)], dim=1)
        # [N, K*II] -> [N, K, II]: element [i, k, t] is step t + k*II.
        P_mod = P.reshape(num_nodes, num_folds, II).sum(dim=1)

        mres_t = torch.matmul(self.w, P_mod)
        L_mres = self.tau * torch.logsumexp(mres_t / self.tau, dim=0)
        V_mres = torch.sum(torch.relu(mres_t - R_cap))

        return L_mres, V_mres

    def _loss_recurrence(self, P: torch.Tensor, II: int) -> torch.Tensor:
        """Expected recurrence violations for loop-carried deps (cf. Equation 24).

        A back edge (v_i, v_j, k) means v_j, k iterations later, consumes
        v_i's result. It therefore needs s_j + k·II >= s_i + 1, and is
        violated when X_j + k·II <= X_i:

            V_rec = Σ_{(i,j,k)∈E_B} Σ_{d_i >= k·II} P_i^{d_i} · P(X_j <= d_i - k·II)

        This is the same constraint ``_legalize_modulo`` enforces and
        ``_count_violations`` counts. Assumes k·II >= 0 (TODO confirm
        iteration distances are always positive).
        """
        V_rec = torch.tensor(0.0, device=self.device)

        for (vi, vj, k) in self.graph.back_edges:
            if vi == vj:
                # s + k·II >= s + 1 holds for any k·II >= 1.
                continue
            shift = k * II
            if shift >= self.D:
                # d_i - shift < 0 <= X_j for every step: never violated.
                continue
            cdf_j = torch.cumsum(P[vj], dim=0)  # P(X_j <= d)
            V_rec = V_rec + torch.sum(P[vi, shift:] * cdf_j[: self.D - shift])

        return V_rec

    def _loss_memory(self, P: torch.Tensor) -> torch.Tensor:
        """Expected peak memory footprint (Equations 9-11).

        Active(v_i, d) = P(X_i ≤ d) · P(max_{j∈succ(i)} X_j > d)
        Mem(d) = Σ_i b_i · Active(v_i, d)
        L_mem = τ · log(Σ_d exp(Mem(d)/τ))

        Start times are treated as independent, so P(max X_j <= d) is the
        product of the successors' CDFs. A node without successors gets
        all_succ_done = 1 and therefore contributes nothing.
        """
        cdf = torch.cumsum(P, dim=1)

        all_succ_done = torch.ones(self.N, self.D, device=self.device)
        for i in range(self.N):
            for j in self.graph.successors[i]:
                all_succ_done[i] *= cdf[j]

        active = cdf * (1.0 - all_succ_done)

        mem_d = torch.matmul(self.b, active)

        L_mem = self.tau * torch.logsumexp(mem_d / self.tau, dim=0)
        return L_mem

    def _legalize_regular(self, s: np.ndarray) -> np.ndarray:
        """Algorithm 2: Regular schedule legalization via topological pass.

        Each start time is clamped into [ASAP, ALAP], and every node is then
        pushed to at least one step after its latest predecessor. The ASAP
        bound is applied last. When D is shorter than the critical path
        (ALAP < ASAP), times are therefore never clamped below ASAP, and in
        particular never made negative.
        """
        s_new = np.maximum(np.minimum(s, self.s_alap), self.s_asap).astype(int)

        for v in self.graph.topological_sort():
            preds = self.graph.predecessors[v]
            if preds:
                t_req = max(s_new[p] for p in preds) + 1
                s_new[v] = max(s_new[v], t_req)

        return s_new

    def _legalize_modulo(self, s: np.ndarray, II: int) -> np.ndarray:
        """Algorithm 3: Modulo schedule legalization via fixed-point iteration.

        Repeatedly raises each node to the latest of:
          * 0,
          * one step after every forward predecessor,
          * s_src - k·II + 1 for every back edge (src, v, k) with src != v.
        It stops at a fixed point or after N sweeps. Infeasible recurrences
        (II too small) can therefore leave violations, which
        ``_count_violations`` reports.
        """
        s_new = np.array(s, dtype=int)  # copy; the caller's array is untouched
        topo = self.graph.topological_sort()

        # Index back edges by consumer once instead of rescanning them per node.
        back_in: Dict[int, List[Tuple[int, int]]] = {}
        for (vi, vj, k) in self.graph.back_edges:
            if vi != vj:
                back_in.setdefault(vj, []).append((vi, k))

        for _ in range(self.N):
            changed = False
            for v in topo:
                t_req = max((s_new[u] for u in self.graph.predecessors[v]), default=-1) + 1
                for (vi, k) in back_in.get(v, ()):
                    t_req = max(t_req, s_new[vi] - k * II + 1)
                if t_req > s_new[v]:
                    s_new[v] = t_req
                    changed = True

            if not changed:
                break

        return s_new

    def _legalize(self, s: np.ndarray, II: Optional[int]) -> np.ndarray:
        """Use the modulo legalizer when II is given, else the regular one."""
        if II is not None:
            return self._legalize_modulo(s, II)
        return self._legalize_regular(s)

    def _count_violations(self, s: np.ndarray, II: Optional[int] = None) -> int:
        """Count constraint violations in a discrete schedule.

        Counts forward edges with s_v <= s_u. When II is given, it also
        counts back edges (u, v, k) with s_v + k·II <= s_u, i.e. the
        consumer starts before the producer's previous-iteration result
        is ready.
        """
        count = sum(1 for (u, v) in self.graph.edges if s[v] <= s[u])
        if II is not None:
            count += sum(
                1 for (vi, vj, k) in self.graph.back_edges
                if s[vj] + k * II <= s[vi]
            )
        return count

    def solve_regular(
        self,
        max_iters: int = 2000,
        legalize_every: int = 200,
        alpha_com: float = 0.1,
        R_cap: Optional[float] = None,
        verbose: bool = True,
    ) -> "GausResult":
        """Solve Formulation A: latency-constrained resource + communication optimization.

        Args:
            max_iters: Maximum optimization iterations
            legalize_every: Legalize (and warm-restart if needed) every N iterations
            alpha_com: Weight for communication objective
            R_cap: Resource capacity limit (None = no hard cap)
            verbose: Print progress
        """
        return self._solve(
            formulation="A",
            max_iters=max_iters,
            legalize_every=legalize_every,
            alpha_com=alpha_com,
            R_cap=R_cap,
            II=None,
            verbose=verbose,
        )

    def solve_memory(
        self,
        max_iters: int = 2000,
        legalize_every: int = 200,
        verbose: bool = True,
    ) -> "GausResult":
        """Solve Formulation B: latency-constrained peak-memory optimization.

        Args:
            max_iters: Maximum optimization iterations
            legalize_every: Legalize (and warm-restart if needed) every N iterations
            verbose: Print progress
        """
        return self._solve(
            formulation="B",
            max_iters=max_iters,
            legalize_every=legalize_every,
            R_cap=None,
            II=None,
            verbose=verbose,
        )

    def solve_modulo(
        self,
        II: int,
        R_cap: float = 1.0,
        max_iters: int = 3000,
        legalize_every: int = 300,
        verbose: bool = True,
        alpha_com: float = 0.1,
    ) -> "GausResult":
        """Solve Formulation C: modulo scheduling (pipelined).

        This is the formulation directly comparable to Twill's domain.

        Args:
            II: Initiation interval (target pipeline rate), must be >= 1
            R_cap: Per-slot resource capacity in modulo reservation table
            max_iters: Maximum optimization iterations
            legalize_every: Legalize (and warm-restart if needed) every N iterations
            verbose: Print progress
            alpha_com: Weight for communication objective
        """
        return self._solve(
            formulation="C",
            max_iters=max_iters,
            legalize_every=legalize_every,
            alpha_com=alpha_com,
            R_cap=R_cap,
            II=II,
            verbose=verbose,
        )

    def _solve(
        self,
        formulation: str,
        max_iters: int,
        legalize_every: int,
        alpha_com: float = 0.1,
        R_cap: Optional[float] = None,
        II: Optional[int] = None,
        verbose: bool = True,
    ) -> "GausResult":
        """Core optimization loop implementing Algorithm 1.

        Args:
            formulation: "A" (resource + communication), "B" (memory) or
                "C" (modulo resource + communication + recurrences).
            max_iters: Number of Adam steps.
            legalize_every: Period, in iterations, of round + legalize.
            alpha_com: Communication weight (Formulations A and C).
            R_cap: Resource cap; must not be None for Formulation C.
            II: Initiation interval; required for Formulation C.
            verbose: Print progress.

        Returns:
            GausResult holding the legalized schedule with the fewest
            violations seen, ties broken by the relaxed primary objective.

        Raises:
            ValueError: unknown formulation, ``legalize_every < 1``, or a
                missing / non-positive ``II`` for Formulation C.
        """
        if formulation not in ("A", "B", "C"):
            raise ValueError(f"Unknown formulation: {formulation}")
        if legalize_every < 1:
            raise ValueError(f"legalize_every must be >= 1, got {legalize_every}")
        modulo = formulation == "C"
        if modulo and (II is None or II < 1):
            raise ValueError(f"Formulation C requires a positive II, got {II}")
        # Only Formulation C legalizes against back edges.
        legal_II = II if modulo else None

        start_time = time.perf_counter()

        if verbose:
            print(f"GauS Solver — Formulation {formulation}")
            print(f" |V|={self.N}, |E|={len(self.graph.edges)}, D={self.D}")
            if II: print(f" II={II}, R_cap={R_cap}")
            print(f" ASAP: {self.s_asap}")
            print(f" ALAP: {self.s_alap}")

        mu, sigma = self._init_params()
        optimizer = optim.Adam([mu, sigma], lr=self.lr)

        # ALM multipliers, one per constraint family.
        lambda_dep = torch.tensor(1e-6, device=self.device)
        lambda_res = torch.tensor(1e-6, device=self.device)
        lambda_mres = torch.tensor(1e-6, device=self.device)
        lambda_rec = torch.tensor(1e-6, device=self.device)
        half_rho = self.rho / 2

        loss_history: List[float] = []
        best_schedule = None
        best_objective = float('inf')
        best_violations = float('inf')

        for it in range(max_iters):
            optimizer.zero_grad()

            P = self._compute_P(mu, sigma)

            # Constraint families a formulation does not use stay at zero.
            zero = torch.tensor(0.0, device=self.device)
            V_res = V_mres = V_rec = zero
            if formulation == "A":
                L_res, V_res = self._loss_resource(P, R_cap)
                L_primary = L_res + alpha_com * self._loss_communication(P)
            elif formulation == "B":
                L_primary = self._loss_memory(P)
            else:
                L_mres, V_mres = self._loss_modulo_resource(P, II, R_cap)
                L_primary = L_mres + alpha_com * self._loss_communication(P)
                V_rec = self._loss_recurrence(P, II)

            V_dep = self._loss_dependency(P)

            # Augmented Lagrangian: λ·V + (ρ/2)·V² per constraint family.
            L_total = L_primary
            L_total = L_total + lambda_dep * V_dep + half_rho * V_dep ** 2
            L_total = L_total + lambda_res * V_res + half_rho * V_res ** 2
            if modulo:
                L_total = L_total + lambda_mres * V_mres + half_rho * V_mres ** 2
                L_total = L_total + lambda_rec * V_rec + half_rho * V_rec ** 2

            L_total.backward()
            optimizer.step()

            # Dual ascent on the multipliers.
            with torch.no_grad():
                lambda_dep = lambda_dep + self.rho * V_dep.detach()
                lambda_res = lambda_res + self.rho * V_res.detach()
                if modulo:
                    lambda_mres = lambda_mres + self.rho * V_mres.detach()
                    lambda_rec = lambda_rec + self.rho * V_rec.detach()

            loss_history.append(L_total.item())

            if (it + 1) % legalize_every == 0 or it == max_iters - 1:
                with torch.no_grad():
                    s_rounded = torch.round(mu).cpu().numpy().astype(int)
                s_legal = self._legalize(s_rounded, legal_II)
                violations = self._count_violations(s_legal, II)

                # Ties are broken by the relaxed objective of this iteration,
                # not an objective re-evaluated on the discrete schedule.
                obj_val = L_primary.item()
                if violations < best_violations or (violations == best_violations and obj_val < best_objective):
                    best_violations = violations
                    best_objective = obj_val
                    best_schedule = s_legal.copy()

                if verbose and ((it + 1) % legalize_every == 0):
                    sigma_mean = sigma.abs().mean().item()
                    print(f" iter {it+1:5d}: L={L_total.item():.4f}, "
                          f"V_dep={V_dep.item():.4f}, "
                          f"σ_mean={sigma_mean:.4f}, "
                          f"violations={violations}")

                # Warm restart: if rounding the relaxed solution breaks a
                # constraint, snap μ onto the legalized schedule so the next
                # phase starts from the repaired point.
                if self._count_violations(s_rounded, II) > 0:
                    with torch.no_grad():
                        mu.copy_(torch.as_tensor(s_legal, dtype=mu.dtype, device=mu.device))

        with torch.no_grad():
            s_final = torch.round(mu).cpu().numpy().astype(int)
        s_final = self._legalize(s_final, legal_II)
        final_violations = self._count_violations(s_final, II)

        # Prefer the best checkpoint if the final iterate is worse.
        if best_schedule is not None and final_violations > best_violations:
            s_final = best_schedule
            final_violations = best_violations

        solve_time = time.perf_counter() - start_time

        schedule = {i: int(s_final[i]) for i in range(self.N)}

        if verbose:
            print(f"\n DONE in {solve_time:.2f}s, {max_iters} iterations")
            print(f" Final schedule: {self._format_schedule(schedule)}")
            print(f" Violations: {final_violations}")
            if II: print(f" II: {II}")

        return GausResult(
            schedule=schedule,
            initiation_interval=II,
            objective_value=best_objective,
            num_violations=final_violations,
            solve_time_seconds=solve_time,
            iterations=max_iters,
            loss_history=loss_history,
            node_names=self.graph.node_names,
        )

    def _format_schedule(self, schedule: Dict[int, int]) -> str:
        """Render "name@step" pairs sorted by start step."""
        names = self.graph.node_names
        return ", ".join(f"{names[i]}@{t}" for i, t in sorted(schedule.items(), key=lambda x: x[1]))
|
|
|
|
| |
| |
| |
|
|
def twill_graph_to_gaus(
    twill_graph,
    D: Optional[int] = None,
) -> Tuple[GausGraph, Dict[str, int]]:
    """Convert a Twill DependenceGraph to GauS format.

    Edges with ``iteration_delay == 0`` become forward (intra-iteration)
    edges. Every other edge becomes a loop-carried back edge
    ``(src, dst, iteration_delay)``.

    Args:
        twill_graph: Twill DependenceGraph object. This function reads
            ``twill_graph.V`` (nodes with ``name``, ``rrt`` and
            ``memory_footprint``) and ``twill_graph.E`` (edges with ``src``,
            ``dst`` and ``iteration_delay``).
        D: Unused; accepted for backward compatibility. The schedule depth
            is chosen when constructing ``GauSSolver``.

    Returns:
        (GausGraph, name_to_index mapping)

    Raises:
        ValueError: if two nodes share a name, which would make edge
            endpoints ambiguous.
        KeyError: if an edge references a node name not in ``V``.
    """
    V = twill_graph.V
    E = twill_graph.E

    name_to_idx = {v.name: i for i, v in enumerate(V)}
    if len(name_to_idx) != len(V):
        raise ValueError("Twill graph has duplicate node names; edge endpoints would be ambiguous")

    forward_edges: List[Tuple[int, int]] = []
    back_edges: List[Tuple[int, int, int]] = []
    for e in E:
        src_idx = name_to_idx[e.src]
        dst_idx = name_to_idx[e.dst]
        if e.iteration_delay == 0:
            forward_edges.append((src_idx, dst_idx))
        else:
            back_edges.append((src_idx, dst_idx, e.iteration_delay))

    # Resource weight: total of each node's reservation table (rrt).
    resource_weights = np.array([v.rrt.sum() for v in V], dtype=np.float64)

    # Memory weight: sum of the footprint values; 0 when none is recorded.
    memory_weights = np.array([
        sum(v.memory_footprint.values()) if v.memory_footprint else 0
        for v in V
    ], dtype=np.float64)

    graph = GausGraph(
        num_nodes=len(V),
        edges=forward_edges,
        back_edges=back_edges,
        resource_weights=resource_weights,
        memory_weights=memory_weights,
        node_names=[v.name for v in V],
    )

    return graph, name_to_idx
|
|
|
|
def gaus_solve_twill_graph(
    twill_graph,
    target_II: Optional[int] = None,
    D: Optional[int] = None,
    max_iters: int = 3000,
    verbose: bool = True,
) -> GausResult:
    """Run GauS directly on a Twill DependenceGraph.

    Uses modulo scheduling (Formulation C, per-slot capacity 1.0) when
    ``target_II`` is given, and regular scheduling (Formulation A)
    otherwise.

    Args:
        twill_graph: Twill DependenceGraph
        target_II: Initiation interval for modulo scheduling
        D: Max depth; when falsy it defaults to the sum of node ``cycles``
            plus the node count
        max_iters: Max optimization iterations
        verbose: Print progress

    Returns:
        GausResult with named schedule
    """
    gaus_graph, _ = twill_graph_to_gaus(twill_graph, D)
    depth = D or (sum(v.cycles for v in twill_graph.V) + len(twill_graph.V))
    solver = GauSSolver(gaus_graph, D=depth)

    if target_II is None:
        return solver.solve_regular(max_iters=max_iters, verbose=verbose)

    return solver.solve_modulo(
        II=target_II,
        R_cap=1.0,
        max_iters=max_iters,
        verbose=verbose,
    )
|
|
|
|
| |
| |
| |
|
|
def generate_random_dag(
    num_nodes: int,
    edge_density: float = 0.3,
    max_weight: int = 3,
    num_back_edges: int = 0,
    seed: int = 42,
) -> GausGraph:
    """Build a reproducible random DAG for benchmarking.

    Each index pair (i, j) with i < j gets a forward edge with probability
    ``edge_density``. If no edge is drawn, the chain 0 -> 1 -> ... -> n-1
    is used instead. Resource and memory weights are integers in
    [1, max_weight]. Each back edge runs from a node in the upper half of
    the index range to one in the lower half, with distance 1 or 2.

    Args:
        num_nodes: Number of operators
        edge_density: Probability of edge between valid pairs
        max_weight: Max resource weight per node
        num_back_edges: Number of loop-carried back-edges to add
        seed: Random seed

    Returns:
        GausGraph
    """
    rng = np.random.RandomState(seed)

    # Draws happen in a fixed order, so a given seed always yields the same graph.
    edges = [
        (src, dst)
        for src in range(num_nodes)
        for dst in range(src + 1, num_nodes)
        if rng.random() < edge_density
    ]
    if not edges:
        edges = [(idx, idx + 1) for idx in range(num_nodes - 1)]

    resource_weights = rng.randint(1, max_weight + 1, size=num_nodes).astype(np.float64)
    memory_weights = rng.randint(1, max_weight + 1, size=num_nodes).astype(np.float64)

    half = num_nodes // 2
    back_edges = [
        (rng.randint(half, num_nodes), rng.randint(0, half), rng.randint(1, 3))
        for _ in range(num_back_edges)
    ]

    return GausGraph(
        num_nodes=num_nodes,
        edges=edges,
        back_edges=back_edges,
        resource_weights=resource_weights,
        memory_weights=memory_weights,
    )
|
|