twill-swp-ws / twill /modulo_scheduler.py
AshenNav's picture
Upload twill/modulo_scheduler.py with huggingface_hub
b84ba2b verified
"""
Phase 1: Optimal Modulo Scheduling via Integer Linear Programming (ILP).
Based on Section 3.1, 4.1, and 5.1 of the paper.
Uses the ILP formulation from Stoutchinin et al. (referenced as [stoutchinin-ilp]).
The modulo scheduling problem:
Given G = (V, E) and target initiation interval I,
find M: V -> [0, L) such that:
1. Dependence: ∀(u,v,d,δ)∈E: M(v) - M(u) + I·δ ≥ d
2. Resource: modular RRT fits within machine capacities
3. Minimize L (schedule length) subject to the above
Uses CBC solver via PuLP.
"""
import pulp
import numpy as np
from typing import Dict, List, Optional, Tuple
from twill.graph import DependenceGraph, Instruction, DependenceEdge
class ModuloScheduleResult:
    """Result of modulo scheduling.

    Attributes:
        schedule: Dict mapping instruction name -> clock cycle M(v).
        initiation_interval: I.
        length: L (total schedule length). The ILP scheduler in this module
            assigns it explicitly (max over v of M(v) + cycles(v)). If it is
            never assigned, it is derived as max M(v) + 1, which assumes
            single-cycle instructions because this class does not know
            per-instruction cycle counts. An empty schedule has length 0.
        num_copies: ceil(L/I) - number of overlapping iterations.
    """

    def __init__(self, schedule: Dict[str, int], I: int):
        self.schedule = schedule
        self.initiation_interval = I
        # Explicit override for `length`; None means "derive from schedule".
        self._length: Optional[int] = None

    @property
    def I(self) -> int:
        """Short alias for ``initiation_interval``."""
        return self.initiation_interval

    @property
    def length(self) -> int:
        """L: the explicitly assigned length, else max M(v) + 1 (0 if empty)."""
        if self._length is not None:
            return self._length
        # Fallback assumes unit-cycle instructions; default=-1 makes an empty
        # schedule report length 0 instead of raising from max().
        return max(self.schedule.values(), default=-1) + 1

    @length.setter
    def length(self, val: int):
        self._length = val

    @property
    def num_copies(self) -> int:
        """ceil(L/I) - number of overlapping iteration copies."""
        # Exact integer ceiling division: avoids float rounding for large L.
        return -(-self.length // self.I)

    def __repr__(self):
        return (f"ModuloSchedule(I={self.I}, L={self.length}, copies={self.num_copies}, "
                f"schedule={self.schedule})")
def _schedule_upper_bound(V, E, I: int) -> int:
    """Return a heuristic upper bound for every issue cycle M(v).

    A bound of I * (n + 1) * max_cycles alone ignores dependence latencies:
    one edge whose delay exceeds it would make every schedule infeasible.
    Adding the sum of all positive edge delays only ever loosens the bound,
    so no schedule admitted by the smaller bound is excluded.
    """
    max_cycles = max(v.cycles for v in V)
    delay_slack = sum(max(0, e.delay) for e in E)
    return I * (len(V) + 1) * max_cycles + delay_slack


def _add_modular_resource_constraints(prob, V, machine, slot, I: int) -> None:
    """Add sum(usage * slot) <= capacity for every (time slot t, unit f).

    An instruction v issued in slot s = M(v) mod I occupies unit f at slot
    (s + c) mod I during its c-th cycle. So slot t sees v's cycle-c usage
    exactly when s == (t - c) mod I.
    """
    for t in range(I):
        for f_idx, f_name in enumerate(machine.functional_units):
            cap = machine.capacity(f_name)
            terms = []
            for v in V:
                for c in range(v.cycles):
                    usage = int(v.rrt[c, f_idx])
                    if usage > 0:
                        terms.append(usage * slot[v.name, (t - c) % I])
            # Units with capacity <= 0 are constrained too: any use of such a
            # unit makes the model infeasible. This agrees with
            # validate_schedule, which flags usage > capacity for every unit.
            if terms:
                prob += pulp.lpSum(terms) <= cap


def optimal_modulo_schedule(
    graph: DependenceGraph,
    target_I: int,
    solver_time_limit: int = 120,
    verbose: bool = False,
) -> Optional[ModuloScheduleResult]:
    """Find an optimal modulo schedule with the given initiation interval.

    Uses an ILP formulation: minimize the schedule length L subject to the
    dependence constraints and the modular resource constraints.

    Args:
        graph: The loop dependence graph.
        target_I: Target initiation interval (must be >= 1 for a schedule).
        solver_time_limit: Time limit for the solver in seconds.
        verbose: Print solver output.

    Returns:
        ModuloScheduleResult if the solver reports an optimal status. None if
        target_I < 1 or the solver reports any other status (e.g. the problem
        is infeasible for this I).

    Raises:
        ValueError: If the graph contains no instructions.
    """
    I = target_I
    if I < 1:
        # No modulo schedule exists for a non-positive initiation interval.
        return None
    V = graph.V
    E = graph.E
    machine = graph.machine
    if not V:
        raise ValueError("cannot modulo-schedule a dependence graph with no instructions")

    prob = pulp.LpProblem(f"ModuloSchedule_I{I}", pulp.LpMinimize)

    big_M = _schedule_upper_bound(V, E, I)
    max_cycles = max(v.cycles for v in V)

    # Variable names use the instruction's index, not its name: PuLP rewrites
    # characters such as '-', '+', ' ' and '/' to '_', so name-derived
    # variable names could collide for distinct instructions.
    M = {
        v.name: pulp.LpVariable(f"M_{idx}", lowBound=0, upBound=big_M, cat='Integer')
        for idx, v in enumerate(V)
    }

    # Auxiliary variable for the schedule length L = max(M(v) + cycles(v)).
    L = pulp.LpVariable("L", lowBound=1, upBound=big_M + max_cycles, cat='Integer')
    prob += L  # objective: minimize L
    for v in V:
        prob += L >= M[v.name] + v.cycles

    # Dependence constraints (Section 3.1):
    # for every (u,v,d,δ) in E: M(v) + I·δ >= M(u) + d,
    # i.e. M(v) - M(u) >= d - I·δ.
    for e in E:
        prob += M[e.dst] - M[e.src] >= e.delay - I * e.iteration_delay

    # Linearize M(v) mod I: slot[v, s] == 1 iff M(v) mod I == s,
    # with M(v) = q_v * I + s for an integer q_v.
    slot = {}
    for idx, v in enumerate(V):
        for s in range(I):
            slot[v.name, s] = pulp.LpVariable(f"slot_{idx}_{s}", cat='Binary')
        prob += pulp.lpSum(slot[v.name, s] for s in range(I)) == 1
        q_v = pulp.LpVariable(f"q_{idx}", lowBound=0, upBound=big_M // I, cat='Integer')
        prob += M[v.name] == q_v * I + pulp.lpSum(s * slot[v.name, s] for s in range(I))

    _add_modular_resource_constraints(prob, V, machine, slot, I)

    solver = pulp.PULP_CBC_CMD(msg=1 if verbose else 0, timeLimit=solver_time_limit)
    status = prob.solve(solver)
    if status != pulp.constants.LpStatusOptimal:
        return None

    schedule = {v.name: int(round(pulp.value(M[v.name]))) for v in V}
    result = ModuloScheduleResult(schedule, I)
    result.length = int(round(pulp.value(L)))
    return result
def compute_modular_rrt(
    graph: DependenceGraph,
    schedule: ModuloScheduleResult,
) -> np.ndarray:
    """Build the steady-state (modular) resource reservation table.

    Each instruction v issued at cycle M(v) contributes its per-cycle usage
    ``v.rrt[c, f]`` to row ``(M(v) + c) mod I`` for each of its ``v.cycles``
    cycles. Row t therefore holds the total demand on every functional unit
    at time slot t in steady state.

    Returns:
        Integer np.ndarray of shape (I, num_functional_units).
    """
    ii = schedule.I
    fu_count = graph.machine.num_functional_units
    table = np.zeros((ii, fu_count), dtype=int)
    for instr in graph.V:
        issue_cycle = schedule.schedule[instr.name]
        for offset in range(instr.cycles):
            # `row` is a view into `table`, so in-place adds update the table.
            row = table[(issue_cycle + offset) % ii]
            for fu in range(fu_count):
                row[fu] += int(instr.rrt[offset, fu])
    return table
def validate_schedule(
    graph: DependenceGraph,
    schedule: ModuloScheduleResult,
) -> Tuple[bool, List[str]]:
    """Validate that a modulo schedule satisfies all constraints.

    Checks performed:
      * every instruction of the graph, and every edge endpoint, has an entry
        in ``schedule.schedule``;
      * the initiation interval I is at least 1;
      * every dependence edge (u, v, d, δ) satisfies M(v) + I·δ >= M(u) + d;
      * the modular RRT never exceeds the machine's capacity vector.

    Missing entries and an invalid I are reported as violations instead of
    raising. In those cases the resource check, which needs a complete
    schedule and I >= 1, is skipped.

    Returns:
        Tuple of (is_valid, list of violation messages).
    """
    violations: List[str] = []
    I = schedule.I
    M = schedule.schedule

    # dict.fromkeys de-duplicates while keeping first-seen order.
    required = [v.name for v in graph.V]
    for e in graph.E:
        required.extend((e.src, e.dst))
    missing = [name for name in dict.fromkeys(required) if name not in M]
    for name in missing:
        violations.append(f"Missing schedule entry for instruction {name!r}")

    if I < 1:
        violations.append(f"Invalid initiation interval I={I}; must be >= 1")

    # Check dependence constraints
    for e in graph.E:
        if e.src not in M or e.dst not in M:
            continue  # already reported as missing
        m_src = M[e.src]
        m_dst = M[e.dst]
        if m_dst + I * e.iteration_delay < m_src + e.delay:
            violations.append(
                f"Dependence violation: {e.src}({m_src}) -> {e.dst}({m_dst}), "
                f"need M({e.dst}) + I*{e.iteration_delay} >= M({e.src}) + {e.delay}, "
                f"got {m_dst + I * e.iteration_delay} < {m_src + e.delay}"
            )

    if missing or I < 1:
        # The modular RRT cannot be built without every M(v) and a positive I.
        return False, violations

    # Check resource constraints
    mod_rrt = compute_modular_rrt(graph, schedule)
    cap_vec = graph.machine.capacity_vector
    for t in range(I):
        for f in range(graph.machine.num_functional_units):
            if mod_rrt[t, f] > cap_vec[f]:
                violations.append(
                    f"Resource violation at t={t}, {graph.machine.functional_units[f]}: "
                    f"usage={mod_rrt[t, f]} > capacity={cap_vec[f]}"
                )
    return not violations, violations