# twill-swp-ws / twill/gaus_solver.py
# Uploaded by AshenNav: "Upload twill/gaus_solver.py with huggingface_hub"
# Revision 6cf8dee (verified)
"""
GauS: Differentiable Scheduling Optimization via Gaussian Reparameterization
Implementation of the paper by Yaohui Cai et al. (arXiv:2602.20427)
GauS models operator scheduling as a stochastic relaxation using Gaussian
distributions, optimized via gradient descent with an Augmented Lagrangian
Method (ALM). It supports:
- Formulation A: Latency-constrained resource + communication optimization
- Formulation B: Latency-constrained memory footprint optimization
- Formulation C: Modulo scheduling (pipelined) — directly comparable to Twill
Key advantages over ILP/SMT (Twill Phase 1+2):
- O(|V|) parameters vs O(D·|V|) for categorical approaches
- Scales to 10K+ operator graphs where ILP/SMT time out
- Exploits GPU parallelism for gradient computation
This module integrates with Twill's DependenceGraph as an alternative solver.
"""
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import math
import time
from typing import Dict, List, Optional, Tuple, Set
from dataclasses import dataclass, field
# ============================================================
# Gaussian CDF helper
# ============================================================
def gaussian_cdf(x: torch.Tensor) -> torch.Tensor:
    """Evaluate the standard normal CDF element-wise.

    Uses the identity Φ(x) = (1 + erf(x / √2)) / 2.
    """
    scaled = x / math.sqrt(2.0)
    return 0.5 * (1.0 + torch.erf(scaled))
# ============================================================
# GauS Graph representation (standalone, bridges to Twill)
# ============================================================
@dataclass
class GausGraph:
    """Graph representation for GauS solver.

    Attributes:
        num_nodes: Number of operators |V|
        edges: List of (src, dst) forward dependency edges
        back_edges: List of (src, dst, k) loop-carried back-edges with
            iteration distance k
        resource_weights: Per-node resource demand w_i, shape [|V|]
            (defaults to all ones)
        memory_weights: Per-node storage bitwidth b_i, shape [|V|]
            (defaults to all ones)
        successors: Dict mapping node -> list of successor nodes
            (built from ``edges`` in ``__post_init__``)
        predecessors: Dict mapping node -> list of predecessor nodes
            (built from ``edges`` in ``__post_init__``)
        node_names: Optional names for nodes (defaults to "v0", "v1", ...)
    """

    num_nodes: int
    edges: List[Tuple[int, int]]  # (src, dst) forward edges
    back_edges: List[Tuple[int, int, int]] = field(default_factory=list)  # (src, dst, k)
    resource_weights: Optional[np.ndarray] = None  # [|V|]
    memory_weights: Optional[np.ndarray] = None  # [|V|]
    node_names: Optional[List[str]] = None

    def __post_init__(self):
        # Adjacency is derived from the forward edges only; back-edges are
        # kept separately so the forward graph can be ordered topologically.
        self.successors: Dict[int, List[int]] = {i: [] for i in range(self.num_nodes)}
        self.predecessors: Dict[int, List[int]] = {i: [] for i in range(self.num_nodes)}
        for u, v in self.edges:
            self.successors[u].append(v)
            self.predecessors[v].append(u)

        # Defaults: unit weights and synthetic names.
        if self.resource_weights is None:
            self.resource_weights = np.ones(self.num_nodes)
        if self.memory_weights is None:
            self.memory_weights = np.ones(self.num_nodes)
        if self.node_names is None:
            self.node_names = [f"v{i}" for i in range(self.num_nodes)]

    def topological_sort(self) -> List[int]:
        """Return the nodes in topological order (Kahn's algorithm).

        Runs in O(|V| + |E|). Nodes are emitted in FIFO order of becoming
        ready. Nodes that lie on (or behind) a cycle of forward edges never
        reach in-degree zero and are therefore omitted from the result.
        """
        in_degree = [0] * self.num_nodes
        for _, v in self.edges:
            in_degree[v] += 1

        # ``order`` doubles as the FIFO work queue: ``head`` is the read
        # cursor, so no O(n) ``list.pop(0)`` is needed.
        order = [i for i in range(self.num_nodes) if in_degree[i] == 0]
        head = 0
        while head < len(order):
            node = order[head]
            head += 1
            for succ in self.successors[node]:
                in_degree[succ] -= 1
                if in_degree[succ] == 0:
                    order.append(succ)
        return order
# ============================================================
# ASAP / ALAP computation
# ============================================================
def compute_asap(graph: GausGraph) -> np.ndarray:
    """Return the As-Soon-As-Possible step of every node.

    The ASAP step is the length of the longest forward-edge path ending at
    the node, with each edge counting as one step; nodes without
    predecessors get step 0. The result is a float64 array of shape [|V|].
    """
    earliest = np.zeros(graph.num_nodes, dtype=np.float64)
    for node in graph.topological_sort():
        preds = graph.predecessors[node]
        if preds:
            # One step after the latest predecessor.
            latest_pred = max(earliest[p] for p in preds)
            earliest[node] = max(earliest[node], latest_pred + 1)
    return earliest
def compute_alap(graph: GausGraph, D: int) -> np.ndarray:
    """Return the As-Late-As-Possible step of every node for depth ``D``.

    Every node starts at the last step ``D - 1`` and is pulled earlier so
    that it sits at least one step before each of its successors. The
    result is a float64 array of shape [|V|].
    """
    latest = np.full(graph.num_nodes, D - 1, dtype=np.float64)
    order = graph.topological_sort()
    for node in order[::-1]:
        succs = graph.successors[node]
        if succs:
            # One step before the earliest successor.
            earliest_succ = min(latest[s] for s in succs)
            latest[node] = min(latest[node], earliest_succ - 1)
    return latest
# ============================================================
# GauS Solver Result
# ============================================================
@dataclass
class GausResult:
    """Outcome of a GauS solve.

    Attributes:
        schedule: Dict mapping node index -> scheduled time step
        initiation_interval: II (for modulo scheduling, else None)
        objective_value: Final objective value
        num_violations: Number of constraint violations in final schedule
        solve_time_seconds: Wall-clock solve time
        iterations: Number of optimization iterations
        loss_history: List of total loss values per iteration
        node_names: Optional node names for display
    """

    schedule: Dict[int, int]
    initiation_interval: Optional[int] = None
    objective_value: float = 0.0
    num_violations: int = 0
    solve_time_seconds: float = 0.0
    iterations: int = 0
    loss_history: List[float] = field(default_factory=list)
    node_names: Optional[List[str]] = None

    @property
    def is_feasible(self) -> bool:
        """True when the schedule has no constraint violations."""
        return not self.num_violations != 0

    def named_schedule(self) -> Dict[str, int]:
        """Return the schedule keyed by node name.

        Falls back to synthetic ``v<index>`` names when ``node_names`` is
        missing or empty.
        """
        labels = self.node_names
        return {
            (labels[idx] if labels else f"v{idx}"): step
            for idx, step in self.schedule.items()
        }

    def __repr__(self):
        # One field per line, wrapped in "GausResult(" ... ")".
        rows = [
            "GausResult(",
            f" schedule={self.named_schedule()}",
            f" II={self.initiation_interval}",
            f" objective={self.objective_value:.4f}",
            f" violations={self.num_violations}",
            f" feasible={self.is_feasible}",
            f" solve_time={self.solve_time_seconds:.2f}s",
            f" iterations={self.iterations}",
            ")",
        ]
        return "\n".join(rows)
# ============================================================
# Core GauS Solver
# ============================================================
class GauSSolver:
    """Differentiable scheduling solver using Gaussian reparameterization.

    Each operator's start step is modelled as a Gaussian N(μ_i, σ_i²). The
    per-step probabilities P_i^d derived from it feed differentiable
    objective and constraint-violation terms, which are minimised with Adam
    under an augmented-Lagrangian penalty. Periodically the rounded means
    are legalized into a discrete schedule and, if that schedule still has
    violations, μ is warm-restarted from it.

    Back-edge convention: for ``(src, dst, k)`` in ``graph.back_edges`` the
    consumer ``dst`` of iteration ``n + k`` must start strictly after the
    producer ``src`` of iteration ``n``, i.e. ``s[dst] + k·II ≥ s[src] + 1``.
    The relaxed loss, the legalizer and the violation counter all use this
    same inequality.

    Usage:
        solver = GauSSolver(graph, D=10)
        result = solver.solve_regular()      # Formulation A
        result = solver.solve_memory()       # Formulation B
        result = solver.solve_modulo(II=3)   # Formulation C (like Twill)
    """

    def __init__(
        self,
        graph: "GausGraph",
        D: int,
        kappa: float = 1.0 / 6.0,
        rho: float = 1e-4,
        tau: float = 1e-2,
        lr: float = 1e-2,
        device: str = "cpu",
    ):
        """
        Args:
            graph: The scheduling graph
            D: Maximum schedule depth (latency bound), must be >= 1
            kappa: Std init scale factor (σ = κ · (ALAP - ASAP))
            rho: ALM penalty coefficient
            tau: LogSumExp temperature
            lr: Adam learning rate
            device: torch device

        Raises:
            ValueError: if ``D`` is smaller than 1.
        """
        if D < 1:
            raise ValueError(f"D must be >= 1, got {D}")
        self.graph = graph
        self.D = D
        self.kappa = kappa
        self.rho = rho
        self.tau = tau
        self.lr = lr
        self.device = device
        self.N = graph.num_nodes

        # Precompute ASAP/ALAP
        self.s_asap = compute_asap(graph)
        self.s_alap = compute_alap(graph, D)

        # Edge endpoint index tensors for vectorized loss computation.
        # An empty edge list simply yields empty long tensors.
        self.edge_src = torch.tensor(
            [u for u, _ in graph.edges], dtype=torch.long, device=device
        )
        self.edge_dst = torch.tensor(
            [v for _, v in graph.edges], dtype=torch.long, device=device
        )

        # Resource (w) and memory (b) weights, one entry per node.
        self.w = torch.tensor(graph.resource_weights, dtype=torch.float32, device=device)
        self.b = torch.tensor(graph.memory_weights, dtype=torch.float32, device=device)

    def _init_params(self, mu_init: Optional[np.ndarray] = None) -> Tuple[nn.Parameter, nn.Parameter]:
        """Initialize μ and σ parameters (Section 3.1).

        Args:
            mu_init: Optional initial means, one per node. When omitted the
                midpoint (ASAP + ALAP) / 2 is used.

        Returns:
            ``(mu, sigma)`` as trainable parameters of shape [N].
        """
        if mu_init is None:
            # μ₀ = (ASAP + ALAP) / 2
            mu_init = (self.s_asap + self.s_alap) / 2.0
        mu = torch.tensor(mu_init, dtype=torch.float32, device=self.device)

        # σ = κ · (ALAP - ASAP); the freedom is floored at 0.5 so that nodes
        # with no (or negative) slack still start with a non-zero σ.
        freedom = self.s_alap - self.s_asap
        sigma_init = self.kappa * np.maximum(freedom, 0.5)
        sigma = torch.tensor(sigma_init, dtype=torch.float32, device=self.device)
        return nn.Parameter(mu), nn.Parameter(sigma)

    def _compute_P(self, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """Compute P_i^d for all (i, d). Equation 7b.

        P_i^d = Φ((d+0.5-μ_i)/σ_i) - Φ((d-0.5-μ_i)/σ_i)

        Boundary handling (footnote 2 in paper): the first bin is open to
        -inf and the last bin is open to +inf, so each row sums to 1 before
        the final clamp.

        Returns: shape [N, D]
        """
        sigma_safe = sigma.abs() + 1e-6  # [N], σ may drift negative under Adam

        # Bin boundaries -0.5, 0.5, ..., D-0.5 shared by neighbouring bins.
        bounds = torch.arange(self.D + 1, dtype=torch.float32, device=self.device) - 0.5
        z = (bounds.unsqueeze(0) - mu.unsqueeze(1)) / sigma_safe.unsqueeze(1)  # [N, D+1]
        cdf = gaussian_cdf(z)  # [N, D+1]

        # Replace the outermost boundaries by Φ(-inf)=0 and Φ(+inf)=1.
        # Built with ``cat`` (not in-place writes) to stay autograd-friendly.
        cdf = torch.cat(
            [torch.zeros_like(cdf[:, :1]), cdf[:, 1:-1], torch.ones_like(cdf[:, :1])],
            dim=1,
        )
        P = cdf[:, 1:] - cdf[:, :-1]  # [N, D]

        # Clamp to avoid numerical issues
        return P.clamp(min=1e-10, max=1.0)

    # ============================================================
    # Loss functions (Section 3.2 + Appendix B)
    # ============================================================
    def _loss_dependency(self, P: torch.Tensor) -> torch.Tensor:
        """Expected dependency violations (Equation 8).

        V_dep = Σ_{(i,j)∈E} Σ_{d_i=1}^{D-1} Σ_{d_j=0}^{d_i-1} P_i^{d_i} · P_j^{d_j}

        Vectorized: for each edge (i,j), V = Σ_{d_i} P_i^{d_i} · CDF_j(d_i - 1).

        NOTE(review): a consumer placed in the same step as its producer
        contributes no penalty here, whereas ``_count_violations`` and the
        legalizers require a strictly later step — confirm which latency
        model Equation 8 intends.
        """
        num_edges = self.edge_src.numel()
        if num_edges == 0:
            return torch.tensor(0.0, device=self.device)

        P_src = P[self.edge_src]  # [|E|, D]
        P_dst = P[self.edge_dst]  # [|E|, D]

        # CDF_j(d) = cumsum of P_j up to d, shifted right by one step so that
        # column d holds P(consumer strictly before d).
        cdf_dst = torch.cumsum(P_dst, dim=1)
        cdf_shifted = torch.cat(
            [torch.zeros(num_edges, 1, device=self.device), cdf_dst[:, :-1]], dim=1
        )
        return torch.sum(P_src * cdf_shifted)

    def _loss_resource(self, P: torch.Tensor, R_cap: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Expected resource usage + violations (Equations 18-20).

        Res(d) = Σ_i w_i · P_i^d
        L_res = τ · log(Σ_d exp(Res(d)/τ))   (LogSumExp smooth-max)
        V_res = Σ_d ReLU(Res(d) - R)         (violations; 0 when R_cap is None)

        Returns: (L_res, V_res)
        """
        res_d = torch.matmul(self.w, P)  # [D]
        L_res = self.tau * torch.logsumexp(res_d / self.tau, dim=0)
        if R_cap is None:
            V_res = torch.tensor(0.0, device=self.device)
        else:
            V_res = torch.sum(torch.relu(res_d - R_cap))
        return L_res, V_res

    def _loss_communication(self, P: torch.Tensor) -> torch.Tensor:
        """Expected communication overhead (Equation 17), approximated.

        The exact term is
        L_com = Σ_{(i,j)∈E} Σ_{d_i} Σ_{d_j≥d_i} P_i^{d_i} · P_j^{d_j} · (d_j - d_i).
        This implementation uses the cheaper approximation
        L_com ≈ Σ_{(i,j)∈E} (E[d_j] - E[d_i]), which ignores the d_j ≥ d_i
        restriction and reduces to Σ (μ_j - μ_i) as σ → 0.
        """
        if self.edge_src.numel() == 0:
            return torch.tensor(0.0, device=self.device)

        steps = torch.arange(self.D, dtype=torch.float32, device=self.device)  # [D]
        expected_step = torch.matmul(P, steps)  # [N], E[d_i] = Σ_d d · P_i^d
        return torch.sum(expected_step[self.edge_dst] - expected_step[self.edge_src])

    def _loss_modulo_resource(
        self, P: torch.Tensor, II: int, R_cap: float
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Modulo resource usage + violations (Equations 21-22, Appendix B.3).

        P_mod[i,t] = Σ_k P_i^{t + k·II}   (wrap probabilities into II slots)
        MRes(t) = Σ_i w_i · P_mod[i,t]
        L_mres = τ · log(Σ_t exp(MRes(t)/τ))
        V_mres = Σ_t ReLU(MRes(t) - R_cap)

        Returns: (L_mres, V_mres)
        """
        num_nodes, depth = P.shape
        num_wraps = (depth + II - 1) // II

        # Zero-pad the step axis to a multiple of II, then fold it into
        # [wrap, slot] so that step d = k·II + t lands in slot t.
        pad = num_wraps * II - depth
        P_padded = nn.functional.pad(P, (0, pad)) if pad else P
        P_mod = P_padded.reshape(num_nodes, num_wraps, II).sum(dim=1)  # [N, II]

        mres_t = torch.matmul(self.w, P_mod)  # [II]
        L_mres = self.tau * torch.logsumexp(mres_t / self.tau, dim=0)
        V_mres = torch.sum(torch.relu(mres_t - R_cap))
        return L_mres, V_mres

    def _loss_recurrence(self, P: torch.Tensor, II: int) -> torch.Tensor:
        """Expected recurrence violations for loop-carried deps (cf. Equation 24).

        For a back-edge (v_i, v_j, k) — producer v_i, consumer v_j, iteration
        distance k — the schedule must satisfy s_j + k·II ≥ s_i + 1, the same
        inequality enforced by ``_legalize_modulo``. The expected number of
        violations is therefore

            V_rec = Σ_{(v_i,v_j,k)∈E_B} Σ_{d_i} P_i^{d_i} · P(d_j ≤ d_i - k·II).

        Self-loops are skipped: their constraint does not involve the
        schedule (it only bounds II).
        """
        V_rec = torch.tensor(0.0, device=self.device)
        depth = P.shape[1]
        for vi, vj, k in self.graph.back_edges:
            if vi == vj:
                continue
            shift = k * II
            if shift >= depth:
                # Even the latest producer step cannot outrun the consumer.
                continue
            cdf_j = torch.cumsum(P[vj], dim=0)  # cdf_j[m] = P(d_j ≤ m)
            # Producer step d_i = shift + m is violated by consumer steps ≤ m.
            V_rec = V_rec + torch.sum(P[vi][shift:] * cdf_j[: depth - shift])
        return V_rec

    def _loss_memory(self, P: torch.Tensor) -> torch.Tensor:
        """Expected peak memory footprint (Equations 9-11).

        Active(v_i, d) = CDF_i(d) · (1 - Π_{j∈succ(i)} CDF_j(d))
            (v_i has started and at least one successor has not yet)
        Mem(d) = Σ_i b_i · Active(v_i, d)
        L_mem = τ · log(Σ_d exp(Mem(d)/τ))

        Nodes without successors have an empty product (= 1) and hence never
        count as active.
        """
        cdf = torch.cumsum(P, dim=1)  # [N, D], CDF_i(d) = Σ_{d'≤d} P_i^{d'}

        # Per-node product over successor CDFs, assembled out-of-place so
        # autograd never sees an in-place update of a saved tensor.
        ones_row = P.new_ones(P.shape[1])
        rows = []
        for i in range(P.shape[0]):
            succs = self.graph.successors[i]
            rows.append(cdf[succs].prod(dim=0) if succs else ones_row)
        all_succ_done = torch.stack(rows) if rows else torch.ones_like(cdf)

        active = cdf * (1.0 - all_succ_done)  # [N, D]
        mem_d = torch.matmul(self.b, active)  # [D]
        return self.tau * torch.logsumexp(mem_d / self.tau, dim=0)

    # ============================================================
    # Legalization (Appendix A)
    # ============================================================
    def _legalize_regular(self, s: np.ndarray) -> np.ndarray:
        """Algorithm 2: Regular schedule legalization via topological pass.

        Clips every step into [ASAP, ALAP] and then pushes each node to at
        least one step after its latest predecessor. The push can move a
        node past its ALAP step; that is not reported as a violation.
        """
        s_new = np.clip(s, self.s_asap, self.s_alap).astype(int)
        for v in self.graph.topological_sort():
            preds = self.graph.predecessors[v]
            if preds:
                s_new[v] = max(s_new[v], max(s_new[p] for p in preds) + 1)
        return s_new

    def _legalize_modulo(self, s: np.ndarray, II: int) -> np.ndarray:
        """Algorithm 3: Modulo schedule legalization via fixed-point iteration.

        Repeatedly pushes nodes later until every node v satisfies
          * s_v ≥ 0,
          * s_v ≥ s_u + 1 for each forward predecessor u, and
          * s_v ≥ s_p - k·II + 1 for each back-edge (p, v, k) with p ≠ v,
        or until N passes have been made (an infeasible II never converges).
        Self-loops are skipped: they do not constrain the schedule.
        """
        s_new = s.copy().astype(int)
        topo = self.graph.topological_sort()

        # Group incoming back-edges by consumer once instead of rescanning
        # the whole back-edge list for every node in every pass.
        incoming_back: Dict[int, List[Tuple[int, int]]] = {}
        for vi, vj, k in self.graph.back_edges:
            if vi != vj:
                incoming_back.setdefault(vj, []).append((vi, k))

        for _ in range(self.N):
            changed = False
            for v in topo:
                preds = self.graph.predecessors[v]
                t_forward = max((s_new[u] for u in preds), default=-1) + 1
                t_back = max(
                    (s_new[p] - k * II + 1 for p, k in incoming_back.get(v, ())),
                    default=0,
                )
                t_req = max(t_forward, t_back, 0)
                if t_req > s_new[v]:
                    s_new[v] = t_req
                    changed = True
            if not changed:
                break
        return s_new

    def _legalize(self, s: np.ndarray, formulation: str, II: Optional[int]) -> np.ndarray:
        """Dispatch to the modulo legalizer for Formulation C, else the regular one."""
        if formulation == "C" and II:
            return self._legalize_modulo(s, II)
        return self._legalize_regular(s)

    def _count_violations(self, s: np.ndarray, II: Optional[int] = None) -> int:
        """Count constraint violations in a discrete schedule.

        A forward edge (u, v) is violated when v does not start strictly
        after u. When ``II`` is given, a back-edge (v_i, v_j, k) is violated
        when s_j + k·II < s_i + 1 — the same inequality ``_legalize_modulo``
        enforces.
        """
        count = sum(1 for u, v in self.graph.edges if s[v] <= s[u])
        if II is not None:
            count += sum(
                1 for vi, vj, k in self.graph.back_edges if s[vj] + k * II <= s[vi]
            )
        return count

    # ============================================================
    # Main solve methods
    # ============================================================
    def solve_regular(
        self,
        max_iters: int = 2000,
        legalize_every: int = 200,
        alpha_com: float = 0.1,
        R_cap: Optional[float] = None,
        verbose: bool = True,
    ) -> "GausResult":
        """Solve Formulation A: latency-constrained resource + communication optimization.

        Args:
            max_iters: Maximum optimization iterations
            legalize_every: Legalize and warm-restart every N iterations
            alpha_com: Weight for communication objective
            R_cap: Resource capacity limit (None = no hard cap)
            verbose: Print progress
        """
        return self._solve(
            formulation="A",
            max_iters=max_iters,
            legalize_every=legalize_every,
            alpha_com=alpha_com,
            R_cap=R_cap,
            II=None,
            verbose=verbose,
        )

    def solve_memory(
        self,
        max_iters: int = 2000,
        legalize_every: int = 200,
        verbose: bool = True,
    ) -> "GausResult":
        """Solve Formulation B: latency-constrained memory footprint optimization.

        Args:
            max_iters: Maximum optimization iterations
            legalize_every: Legalize and warm-restart every N iterations
            verbose: Print progress
        """
        return self._solve(
            formulation="B",
            max_iters=max_iters,
            legalize_every=legalize_every,
            II=None,
            verbose=verbose,
        )

    def solve_modulo(
        self,
        II: int,
        R_cap: float = 1.0,
        max_iters: int = 3000,
        legalize_every: int = 300,
        verbose: bool = True,
    ) -> "GausResult":
        """Solve Formulation C: modulo scheduling (pipelined).

        This is the formulation directly comparable to Twill's domain.

        Args:
            II: Initiation interval (target pipeline rate), must be >= 1
            R_cap: Per-slot resource capacity in modulo reservation table
            max_iters: Maximum optimization iterations
            legalize_every: Legalize and warm-restart every N iterations
            verbose: Print progress

        Raises:
            ValueError: if ``II`` is smaller than 1.
        """
        if II < 1:
            raise ValueError(f"II must be >= 1, got {II}")
        return self._solve(
            formulation="C",
            max_iters=max_iters,
            legalize_every=legalize_every,
            R_cap=R_cap,
            II=II,
            verbose=verbose,
        )

    def _solve(
        self,
        formulation: str,
        max_iters: int,
        legalize_every: int,
        alpha_com: float = 0.1,
        R_cap: Optional[float] = None,
        II: Optional[int] = None,
        verbose: bool = True,
    ) -> "GausResult":
        """Core optimization loop implementing Algorithm 1.

        Args:
            formulation: "A" (resource + communication), "B" (memory) or
                "C" (modulo scheduling)
            max_iters: Number of Adam iterations
            legalize_every: Legalize (and possibly warm-restart) every N
                iterations; the last iteration is always legalized
            alpha_com: Weight of the communication term (Formulation A)
            R_cap: Resource capacity (A: optional cap, C: per-slot cap)
            II: Initiation interval (Formulation C)
            verbose: Print progress

        Returns:
            GausResult whose ``objective_value`` is the relaxed primary
            objective recorded at the checkpoint of the returned schedule.

        Raises:
            ValueError: if ``formulation`` is not one of "A", "B", "C".
        """
        if formulation not in ("A", "B", "C"):
            raise ValueError(f"Unknown formulation: {formulation}")
        is_modulo = formulation == "C"

        start_time = time.perf_counter()
        if verbose:
            print(f"GauS Solver — Formulation {formulation}")
            print(f" |V|={self.N}, |E|={len(self.graph.edges)}, D={self.D}")
            if II:
                print(f" II={II}, R_cap={R_cap}")
            print(f" ASAP: {self.s_asap}")
            print(f" ALAP: {self.s_alap}")

        # Initialize parameters
        mu, sigma = self._init_params()
        optimizer = optim.Adam([mu, sigma], lr=self.lr)

        # Lagrange multipliers (ALM), one per constraint family.
        lambda_dep = torch.tensor(1e-6, device=self.device)
        lambda_res = torch.tensor(1e-6, device=self.device)
        lambda_mres = torch.tensor(1e-6, device=self.device)
        lambda_rec = torch.tensor(1e-6, device=self.device)
        zero = torch.tensor(0.0, device=self.device)
        half_rho = self.rho / 2

        loss_history = []
        best_schedule = None
        best_objective = float('inf')
        best_violations = float('inf')
        last_objective = float('inf')  # objective at the most recent checkpoint

        for it in range(max_iters):
            optimizer.zero_grad()

            # Compute P_i^d
            P = self._compute_P(mu, sigma)

            # Primary objective plus the formulation-specific violations.
            V_res = V_mres = V_rec = zero
            if formulation == "A":
                L_res, V_res = self._loss_resource(P, R_cap)
                L_primary = L_res + alpha_com * self._loss_communication(P)
            elif formulation == "B":
                L_primary = self._loss_memory(P)
            else:
                L_mres, V_mres = self._loss_modulo_resource(P, II, R_cap)
                # Compactness term: encourage short schedules
                L_primary = L_mres + 0.1 * self._loss_communication(P)
                V_rec = self._loss_recurrence(P, II)
            V_dep = self._loss_dependency(P)

            # Augmented Lagrangian: λ·V + (ρ/2)·V² per constraint family.
            L_total = L_primary
            L_total = L_total + lambda_dep * V_dep + half_rho * V_dep ** 2
            L_total = L_total + lambda_res * V_res + half_rho * V_res ** 2
            if is_modulo:
                L_total = L_total + lambda_mres * V_mres + half_rho * V_mres ** 2
                L_total = L_total + lambda_rec * V_rec + half_rho * V_rec ** 2

            # Backward + optimize
            L_total.backward()
            optimizer.step()

            # Update Lagrange multipliers (dual ascent step).
            with torch.no_grad():
                lambda_dep = lambda_dep + self.rho * V_dep.detach()
                lambda_res = lambda_res + self.rho * V_res.detach()
                if is_modulo:
                    lambda_mres = lambda_mres + self.rho * V_mres.detach()
                    lambda_rec = lambda_rec + self.rho * V_rec.detach()

            loss_history.append(L_total.item())

            # Periodic legalization + warm restart
            on_period = (it + 1) % legalize_every == 0
            if not (on_period or it == max_iters - 1):
                continue

            with torch.no_grad():
                s_rounded = torch.round(mu).cpu().numpy().astype(int)
            s_legal = self._legalize(s_rounded, formulation, II)
            violations = self._count_violations(s_legal, II)

            # Track best: fewest violations first, then lowest objective.
            obj_val = L_primary.item()
            last_objective = obj_val
            if violations < best_violations or (
                violations == best_violations and obj_val < best_objective
            ):
                best_violations = violations
                best_objective = obj_val
                best_schedule = s_legal.copy()

            if verbose and on_period:
                sigma_mean = sigma.abs().mean().item()
                print(f" iter {it+1:5d}: L={L_total.item():.4f}, "
                      f"V_dep={V_dep.item():.4f}, "
                      f"σ_mean={sigma_mean:.4f}, "
                      f"violations={violations}")

            # Warm restart: re-initialize μ from legalized schedule
            if violations > 0:
                with torch.no_grad():
                    mu.copy_(
                        torch.as_tensor(s_legal, dtype=torch.float32, device=self.device)
                    )

        # Final extraction from the current means.
        with torch.no_grad():
            s_final = torch.round(mu).cpu().numpy().astype(int)
        s_final = self._legalize(s_final, formulation, II)
        final_violations = self._count_violations(s_final, II)
        final_objective = last_objective

        # Use best if final has more violations; report the objective that
        # belongs to whichever schedule is actually returned.
        if best_schedule is not None and final_violations > best_violations:
            s_final = best_schedule
            final_violations = best_violations
            final_objective = best_objective

        solve_time = time.perf_counter() - start_time
        schedule = {i: int(s_final[i]) for i in range(self.N)}
        if verbose:
            print(f"\n DONE in {solve_time:.2f}s, {max_iters} iterations")
            print(f" Final schedule: {self._format_schedule(schedule)}")
            print(f" Violations: {final_violations}")
            if II:
                print(f" II: {II}")

        return GausResult(
            schedule=schedule,
            initiation_interval=II,
            objective_value=final_objective,
            num_violations=final_violations,
            solve_time_seconds=solve_time,
            iterations=max_iters,
            loss_history=loss_history,
            node_names=self.graph.node_names,
        )

    def _format_schedule(self, schedule: Dict[int, int]) -> str:
        """Render the schedule as "name@step" entries ordered by step."""
        names = self.graph.node_names
        by_step = sorted(schedule.items(), key=lambda item: item[1])
        return ", ".join(f"{names[i]}@{t}" for i, t in by_step)
# ============================================================
# Integration with Twill's DependenceGraph
# ============================================================
def twill_graph_to_gaus(
    twill_graph,  # twill.graph.DependenceGraph
    D: Optional[int] = None,
) -> Tuple[GausGraph, Dict[str, int]]:
    """Convert a Twill DependenceGraph to GauS format.

    Edges with ``iteration_delay == 0`` become forward edges; all other
    edges become back-edges ``(src, dst, iteration_delay)``.

    Args:
        twill_graph: Twill DependenceGraph object; this function reads its
            ``V`` (nodes with ``name``, ``rrt``, ``memory_footprint``) and
            ``E`` (edges with ``src``, ``dst``, ``iteration_delay``).
        D: Accepted for signature compatibility only. The returned graph
            does not depend on the schedule depth, so the value is unused.

    Returns:
        (GausGraph, name_to_index mapping)
    """
    nodes = twill_graph.V

    # Build name -> index mapping
    name_to_idx = {v.name: i for i, v in enumerate(nodes)}

    # Separate forward edges (δ=0) from back-edges (δ≠0)
    forward_edges = []
    back_edges = []
    for e in twill_graph.E:
        src_idx = name_to_idx[e.src]
        dst_idx = name_to_idx[e.dst]
        if e.iteration_delay == 0:
            forward_edges.append((src_idx, dst_idx))
        else:
            back_edges.append((src_idx, dst_idx, e.iteration_delay))

    # Resource weights: sum of RRT usage across all FUs for each instruction
    resource_weights = np.array([v.rrt.sum() for v in nodes], dtype=np.float64)

    # Memory weights: total of the node's memory_footprint values, or 0 when
    # the footprint is missing/empty.
    memory_weights = np.array(
        [sum(v.memory_footprint.values()) if v.memory_footprint else 0 for v in nodes],
        dtype=np.float64,
    )

    graph = GausGraph(
        num_nodes=len(nodes),
        edges=forward_edges,
        back_edges=back_edges,
        resource_weights=resource_weights,
        memory_weights=memory_weights,
        node_names=[v.name for v in nodes],
    )
    return graph, name_to_idx
def gaus_solve_twill_graph(
    twill_graph,  # twill.graph.DependenceGraph
    target_II: Optional[int] = None,
    D: Optional[int] = None,
    max_iters: int = 3000,
    verbose: bool = True,
    R_cap: float = 1.0,
) -> GausResult:
    """Convenience: solve a Twill DependenceGraph using GauS.

    If target_II is given, uses modulo scheduling (Formulation C).
    Otherwise, uses regular scheduling (Formulation A).

    Args:
        twill_graph: Twill DependenceGraph
        target_II: Initiation interval for modulo scheduling
        D: Max depth. When None (or 0) it is derived as the sum of all node
            ``cycles`` plus the number of nodes.
        max_iters: Max optimization iterations
        verbose: Print progress
        R_cap: Per-slot resource capacity for modulo scheduling. The default
            of 1.0 means one resource per modulo slot; it is ignored when
            ``target_II`` is None.

    Returns:
        GausResult with named schedule
    """
    gaus_graph, _ = twill_graph_to_gaus(twill_graph, D)

    # Fallback depth: every node executed back-to-back, plus one step each.
    depth = D or (sum(v.cycles for v in twill_graph.V) + len(twill_graph.V))
    solver = GauSSolver(gaus_graph, D=depth)

    if target_II is None:
        return solver.solve_regular(
            max_iters=max_iters,
            verbose=verbose,
        )
    return solver.solve_modulo(
        II=target_II,
        R_cap=R_cap,
        max_iters=max_iters,
        verbose=verbose,
    )
# ============================================================
# Synthetic graph generation for scalability testing
# ============================================================
def generate_random_dag(
    num_nodes: int,
    edge_density: float = 0.3,
    max_weight: int = 3,
    num_back_edges: int = 0,
    seed: int = 42,
) -> GausGraph:
    """Generate a random DAG for benchmarking.

    Forward edges only run from a lower to a higher node index, so the
    forward graph is acyclic by construction.

    Args:
        num_nodes: Number of operators
        edge_density: Probability of edge between valid pairs
        max_weight: Max resource weight per node
        num_back_edges: Number of loop-carried back-edges to add
        seed: Random seed

    Returns:
        GausGraph

    Raises:
        ValueError: if back-edges are requested for fewer than two nodes.
    """
    if num_back_edges > 0 and num_nodes < 2:
        raise ValueError("num_back_edges > 0 requires num_nodes >= 2")

    rng = np.random.RandomState(seed)

    # Draw one row of the upper triangle at a time. The uniforms are consumed
    # in the same (i, j) order as a scalar draw per pair, so a given seed
    # yields the same edges, but with O(num_nodes) Python-level iterations.
    edges = []
    for i in range(num_nodes):
        draws = rng.random(num_nodes - i - 1)
        edges.extend((i, i + 1 + int(off)) for off in np.flatnonzero(draws < edge_density))

    # Fallback when no edge was drawn at all: connect the nodes as a chain.
    if not edges:
        edges = [(i, i + 1) for i in range(num_nodes - 1)]

    resource_weights = rng.randint(1, max_weight + 1, size=num_nodes).astype(np.float64)
    memory_weights = rng.randint(1, max_weight + 1, size=num_nodes).astype(np.float64)

    # Back-edges run from a node in the later half to one in the earlier
    # half, with an iteration distance of 1 or 2.
    back_edges = []
    half = num_nodes // 2
    for _ in range(num_back_edges):
        src = rng.randint(half, num_nodes)
        dst = rng.randint(0, half)
        k = rng.randint(1, 3)
        back_edges.append((src, dst, k))

    return GausGraph(
        num_nodes=num_nodes,
        edges=edges,
        back_edges=back_edges,
        resource_weights=resource_weights,
        memory_weights=memory_weights,
    )