# twill-swp-ws / twill/graph.py
# Uploaded by AshenNav ("Upload twill/graph.py"), commit a3e198c (verified).
"""
Core data structures: Dependence Graph, Instructions, RRTs, Machine Description.
Based on Section 3.1 of the paper:
- G = (V, E) where V = instructions, E = dependence edges
- Each instruction v has an RRT (Resource Reservation Table)
- Each edge (u, v, d, δ) has clock delay d and iteration delay δ
- Machine description D defines functional unit capacities and memory capacities
"""
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Set
import numpy as np
@dataclass
class Instruction:
    """A tile-level instruction in the loop body.

    Attributes:
        name: Unique identifier for this instruction.
        rrt: Resource Reservation Table, a 2D array indexed [cycle, functional_unit]
            holding a usage count. Each row is a clock cycle of execution and each
            column is a functional unit type.
        variable_latency: Whether this instruction has variable latency (e.g., TMA loads).
        memory_footprint: Maps memory-space name -> bytes used by this instruction's output.
        streaming: Whether this is a streaming variable-latency op (no incoming data deps).
    """

    name: str
    rrt: np.ndarray  # shape: (cycles, num_functional_units)
    variable_latency: bool = False
    memory_footprint: Dict[str, int] = field(default_factory=dict)
    streaming: bool = False

    @property
    def cycles(self) -> int:
        """Number of clock cycles this instruction occupies (rows of the RRT)."""
        return self.rrt.shape[0]

    @property
    def functional_units_used(self) -> Set[int]:
        """Indices of functional units whose total usage over all cycles is positive.

        The indices are converted to plain Python ints (NumPy returns integer
        scalars) so the result matches the ``Set[int]`` annotation and behaves like
        ordinary ints, e.g. when serialized.
        """
        return {int(idx) for idx in np.flatnonzero(self.rrt.sum(axis=0) > 0)}

    def __repr__(self):
        return f"Instruction({self.name}, cycles={self.cycles}, var_lat={self.variable_latency})"
@dataclass
class DependenceEdge:
    """Data dependence (u, v, d, δ) of the loop dependence graph (Section 3.1).

    The instance of v from iteration i may issue no earlier than d clock cycles
    after the instance of u from iteration i - δ.

    Attributes:
        src: Name of the source instruction u.
        dst: Name of the destination instruction v.
        delay: d, the minimum number of clock cycles from u's issue to v's issue.
        iteration_delay: δ, the iteration distance (0 = same iteration,
            1 = carried from the previous iteration).
    """

    src: str
    dst: str
    delay: int  # d: minimum clock cycles between src issue and dst issue
    iteration_delay: int = 0  # δ: iteration distance

    def __repr__(self):
        template = "Edge({} -> {}, d={}, δ={})"
        return template.format(self.src, self.dst, self.delay, self.iteration_delay)
@dataclass
class MachineDescription:
    """Description of the target GPU architecture.

    Attributes:
        name: Architecture name (e.g., "Hopper", "Blackwell").
        functional_units: Functional unit names (e.g., ["TC", "EXP", "TMA"]); a unit's
            position in this list is its index.
        capacities: Maps functional_unit_name -> capacity (max simultaneous usage).
            Units without an entry are treated as having capacity 0.
        memory_spaces: Maps memory_space_name -> capacity in bytes.
        num_warps: Number of available warps for WS.
        variable_latency_warp: Index of the warp designated for variable-latency ops.
    """

    name: str
    functional_units: List[str]
    capacities: Dict[str, int]
    memory_spaces: Dict[str, int] = field(default_factory=dict)
    num_warps: int = 4
    variable_latency_warp: int = 0  # W_vl

    def capacity(self, fu_name: str) -> int:
        """Capacity of functional unit ``fu_name``; 0 if it has no entry in ``capacities``."""
        return self.capacities.get(fu_name, 0)

    def fu_index(self, fu_name: str) -> int:
        """Index of ``fu_name`` in ``functional_units`` (raises ValueError if absent)."""
        return self.functional_units.index(fu_name)

    @property
    def num_functional_units(self) -> int:
        """Number of functional units."""
        return len(self.functional_units)

    @property
    def capacity_vector(self) -> np.ndarray:
        """Array of capacities indexed by functional unit index.

        Goes through ``capacity()`` so a unit listed in ``functional_units`` but
        missing from ``capacities`` yields 0 instead of raising KeyError, which keeps
        this view consistent with ``capacity()``.
        """
        return np.array([self.capacity(fu) for fu in self.functional_units])
class DependenceGraph:
    """Loop dependence graph G = (V, E).

    This is the primary input to Twill's optimization pipeline. Instructions are
    kept in insertion order; edges are kept in a plain list.

    Usage:
        graph = DependenceGraph(machine)
        graph.add_instruction(Instruction("S", rrt_s))
        graph.add_instruction(Instruction("P", rrt_p))
        graph.add_edge(DependenceEdge("S", "P", delay=1))
        ...
    """

    def __init__(self, machine: "MachineDescription"):
        self.machine = machine
        self.instructions: Dict[str, Instruction] = {}
        self.edges: List[DependenceEdge] = []
        self._instruction_order: List[str] = []  # maintain insertion order

    def add_instruction(self, instr: "Instruction"):
        """Add an instruction to the graph.

        Raises:
            AssertionError: (when assertions are enabled) if the name is already in
                use, or if the RRT width differs from the machine's number of
                functional units.
        """
        assert instr.name not in self.instructions, f"Duplicate instruction: {instr.name}"
        assert instr.rrt.shape[1] == self.machine.num_functional_units, \
            f"RRT width {instr.rrt.shape[1]} != num_functional_units {self.machine.num_functional_units}"
        self.instructions[instr.name] = instr
        self._instruction_order.append(instr.name)

    def add_edge(self, edge: "DependenceEdge"):
        """Add a dependence edge between two already-added instructions.

        Raises:
            AssertionError: (when assertions are enabled) if either endpoint is unknown.
        """
        assert edge.src in self.instructions, f"Unknown source: {edge.src}"
        assert edge.dst in self.instructions, f"Unknown destination: {edge.dst}"
        self.edges.append(edge)

    @property
    def V(self) -> "List[Instruction]":
        """List of instructions in insertion order."""
        return [self.instructions[name] for name in self._instruction_order]

    @property
    def E(self) -> "List[DependenceEdge]":
        """List of dependence edges (the live internal list, not a copy)."""
        return self.edges

    @property
    def num_instructions(self) -> int:
        """Number of instructions |V|."""
        return len(self.instructions)

    def get_instruction(self, name: str) -> "Instruction":
        """Look up an instruction by name (raises KeyError if unknown)."""
        return self.instructions[name]

    def outgoing_edges(self, name: str) -> "List[DependenceEdge]":
        """Get all edges where name is the source."""
        return [e for e in self.edges if e.src == name]

    def incoming_edges(self, name: str) -> "List[DependenceEdge]":
        """Get all edges where name is the destination."""
        return [e for e in self.edges if e.dst == name]

    def has_loop_carried_output(self, name: str) -> bool:
        """Check if instruction has any outgoing loop-carried edge (δ > 0)."""
        return any(e.iteration_delay > 0 for e in self.outgoing_edges(name))

    def get_cycle_counts(self) -> List[int]:
        """Sorted distinct values of instruction cycle counts and edge delays.

        Used for cost normalization.
        """
        delays = set()
        for instr in self.V:
            delays.add(instr.cycles)
        for edge in self.edges:
            delays.add(edge.delay)
        return sorted(delays)

    def compute_min_initiation_interval(self) -> int:
        """Compute a lower bound on the initiation interval I.

        Returns max(1, ResMII, RecMII), where:
          * ResMII = max over functional units f with capacity(f) > 0 of
            ceil(sum of RRT usage of f across all instructions / capacity(f));
          * RecMII = max over dependence cycles C of
            ceil(sum of d over C / sum of δ over C).
        A loop-carried edge that does not lie on a cycle does not constrain I.

        Raises:
            ValueError: if some dependence cycle has zero total iteration delay but
                positive total clock delay, which no initiation interval can satisfy.
        """
        return max(1, self._resource_lower_bound(), self._recurrence_lower_bound())

    def _resource_lower_bound(self) -> int:
        """ResMII: the busiest functional unit's total usage divided by its capacity."""
        bound = 1
        cap_vec = self.machine.capacity_vector
        instrs = self.V
        for fu_idx in range(self.machine.num_functional_units):
            if cap_vec[fu_idx] <= 0:
                # NOTE(review): usage of a unit with capacity <= 0 is ignored here, so an
                # unschedulable graph is not reported by this bound.
                continue
            total_usage = sum(instr.rrt[:, fu_idx].sum() for instr in instrs)
            bound = max(bound, int(np.ceil(total_usage / cap_vec[fu_idx])))
        return bound

    def _recurrence_lower_bound(self) -> int:
        """RecMII: the smallest I >= 1 that satisfies every dependence cycle.

        Edge (u, v, d, δ) requires t_v + δ·I >= t_u + d, i.e. t_v - t_u >= d - δ·I.
        Start times satisfying all edges exist iff the graph weighted by d - δ·I has
        no positive-weight cycle. With δ >= 0 (assumed), larger I only relaxes the
        constraints, so binary search finds the smallest feasible I.
        """
        if not self.edges:
            return 1
        # A simple cycle uses each edge at most once, so its total delay is at most
        # the sum of all positive delays; any cycle with total δ >= 1 is therefore
        # satisfied at this I.
        hi = max(1, sum(e.delay for e in self.edges if e.delay > 0))
        if self._has_positive_cycle(hi):
            raise ValueError(
                "Dependence graph contains a cycle with zero total iteration delay and "
                "positive total clock delay; no initiation interval can satisfy it"
            )
        lo = 1
        while lo < hi:
            mid = (lo + hi) // 2
            if self._has_positive_cycle(mid):
                lo = mid + 1
            else:
                hi = mid
        return lo

    def _has_positive_cycle(self, interval: int) -> bool:
        """True if the graph weighted by d - δ·interval contains a positive-weight cycle."""
        index = {name: k for k, name in enumerate(self._instruction_order)}
        n = len(index)
        neg_inf = float("-inf")
        # dist[i][j]: longest known path weight from i to j (Floyd–Warshall, max-plus).
        dist = [[neg_inf] * n for _ in range(n)]
        for e in self.edges:
            i, j = index[e.src], index[e.dst]
            dist[i][j] = max(dist[i][j], e.delay - e.iteration_delay * interval)
        for k in range(n):
            row_k = dist[k]
            for i in range(n):
                d_ik = dist[i][k]
                if d_ik == neg_inf:
                    continue  # no path i -> k, nothing to relax through k
                row_i = dist[i]
                for j in range(n):
                    cand = d_ik + row_k[j]
                    if cand > row_i[j]:
                        row_i[j] = cand
        # A positive cycle through vertex i shows up as a positive diagonal entry.
        return any(dist[i][i] > 0 for i in range(n))

    def __repr__(self):
        return (f"DependenceGraph(|V|={self.num_instructions}, |E|={len(self.edges)}, "
                f"machine={self.machine.name})")
# ============================================================
# Pre-defined machine descriptions
# ============================================================
def hopper_machine(
    tc_capacity: int = 1,
    exp_capacity: int = 1,
    tma_capacity: int = 1,
) -> MachineDescription:
    """Build the NVIDIA Hopper (H100) machine description.

    Functional units are TC, EXP and TMA, in that order; each keyword argument sets
    the capacity of the matching unit. Memory spaces: SMEM 228 KiB and REGS 256 KiB.
    Four warps, with warp 0 designated for variable-latency ops.
    """
    kib = 1024
    units = ["TC", "EXP", "TMA"]
    caps = dict(zip(units, (tc_capacity, exp_capacity, tma_capacity)))
    return MachineDescription(
        name="Hopper",
        functional_units=units,
        capacities=caps,
        memory_spaces={"SMEM": 228 * kib, "REGS": 256 * kib},
        num_warps=4,
        variable_latency_warp=0,
    )
def blackwell_machine(
    tc_capacity: int = 1,
    exp_capacity: int = 1,
    tma_capacity: int = 1,
    tmem_capacity: int = 1,
) -> MachineDescription:
    """Build the NVIDIA Blackwell (B200) machine description.

    Functional units are TC, EXP, TMA and TMEM, in that order; each keyword argument
    sets the capacity of the matching unit. Memory spaces: SMEM 228 KiB, REGS 256 KiB
    and TMEM 128 KiB. Four warps, with warp 0 designated for variable-latency ops.
    """
    # NOTE(review): public Blackwell docs describe tensor memory as 128 lanes x 512
    # columns of 32-bit cells (256 KiB per SM); confirm the 128 KiB below is intended.
    kib = 1024
    units = ["TC", "EXP", "TMA", "TMEM"]
    caps = dict(zip(units, (tc_capacity, exp_capacity, tma_capacity, tmem_capacity)))
    return MachineDescription(
        name="Blackwell",
        functional_units=units,
        capacities=caps,
        memory_spaces={"SMEM": 228 * kib, "REGS": 256 * kib, "TMEM": 128 * kib},
        num_warps=4,
        variable_latency_warp=0,
    )
def make_rrt(cycles: int, fu_usage: Dict[int, List[int]], num_fus: int) -> np.ndarray:
    """Build a Resource Reservation Table from per-unit usage lists.

    Args:
        cycles: Number of rows (clock cycles) in the table.
        fu_usage: Maps functional-unit column index -> usage counts for cycles
            0, 1, 2, ...; counts at positions >= ``cycles`` are dropped.
        num_fus: Number of columns (functional units).

    Returns:
        Integer array of shape (cycles, num_fus); entries not given in ``fu_usage``
        are 0.
    """
    table = np.zeros((cycles, num_fus), dtype=int)
    for column, per_cycle in fu_usage.items():
        for row, count in enumerate(per_cycle):
            if row >= cycles:
                continue  # usage beyond the table's last cycle is ignored
            table[row, column] = count
    return table