Upload twill/graph.py
Browse files- twill/graph.py +250 -0
twill/graph.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Core data structures: Dependence Graph, Instructions, RRTs, Machine Description.
|
| 3 |
+
|
| 4 |
+
Based on Section 3.1 of the paper:
|
| 5 |
+
- G = (V, E) where V = instructions, E = dependence edges
|
| 6 |
+
- Each instruction v has an RRT (Resource Reservation Table)
|
| 7 |
+
- Each edge (u, v, d, δ) has clock delay d and iteration delay δ
|
| 8 |
+
- Machine description D defines functional unit capacities and memory capacities
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
from typing import Dict, List, Optional, Tuple, Set
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
class Instruction:
    """A tile-level instruction in the loop body.

    Attributes:
        name: Unique identifier for this instruction.
        rrt: Resource Reservation Table, a 2D array indexed as
            [cycle, functional_unit] holding a usage count. Each row is one
            clock cycle of execution; each column is one functional unit type.
        variable_latency: Whether this instruction has variable latency
            (e.g., TMA loads).
        memory_footprint: Dict mapping memory_space -> bytes used by this
            instruction's output.
        streaming: Whether this is a streaming variable-latency op
            (no incoming data deps).
    """
    name: str
    rrt: np.ndarray  # shape: (cycles, num_functional_units)
    variable_latency: bool = False
    memory_footprint: Dict[str, int] = field(default_factory=dict)
    streaming: bool = False

    @property
    def cycles(self) -> int:
        """Number of clock cycles this instruction takes (the RRT row count)."""
        return self.rrt.shape[0]

    @property
    def functional_units_used(self) -> Set[int]:
        """Set of functional unit indices used by this instruction.

        A unit counts as used when its RRT column sums to a positive value.
        The indices are returned as built-in ``int`` objects rather than NumPy
        integer scalars, so the result matches the ``Set[int]`` annotation.
        """
        column_totals = self.rrt.sum(axis=0)
        return {int(idx) for idx in np.where(column_totals > 0)[0]}

    def __repr__(self):
        return f"Instruction({self.name}, cycles={self.cycles}, var_lat={self.variable_latency})"
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass
class DependenceEdge:
    """One dependence edge (u, v, d, δ) of the loop dependence graph.

    Following Section 3.1 of the paper, the edge states that v has to be
    issued no earlier than d cycles after u, where v belongs to iteration i
    and u belongs to iteration i - δ.

    Attributes:
        src: Name of the source instruction (u).
        dst: Name of the destination instruction (v).
        delay: The clock delay d, i.e. the minimum number of cycles between
            the issue of src and the issue of dst.
        iteration_delay: The iteration distance δ. 0 means both endpoints are
            in the same iteration; 1 means the dependence is loop-carried.
    """
    src: str
    dst: str
    delay: int  # d, measured from src issue to dst issue
    iteration_delay: int = 0  # δ, defaults to an intra-iteration edge

    def __repr__(self):
        # Compact form used when printing graphs: endpoints, then d and δ.
        return f"Edge({self.src} -> {self.dst}, d={self.delay}, δ={self.iteration_delay})"
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
class MachineDescription:
    """Description of the target GPU architecture.

    Attributes:
        name: Architecture name (e.g., "Hopper", "Blackwell").
        functional_units: List of functional unit names
            (e.g., ["TC", "EXP", "TMA"]). A unit's position in this list is
            its index, as returned by ``fu_index``.
        capacities: Dict mapping functional_unit_name -> capacity
            (max simultaneous usage).
        memory_spaces: Dict mapping memory_space_name -> capacity in bytes.
        num_warps: Number of available warps for WS.
        variable_latency_warp: Index of the warp designated for
            variable-latency ops.
    """
    name: str
    functional_units: List[str]
    capacities: Dict[str, int]
    memory_spaces: Dict[str, int] = field(default_factory=dict)
    num_warps: int = 4
    variable_latency_warp: int = 0  # W_vl

    def capacity(self, fu_name: str) -> int:
        """Get capacity of a functional unit by name.

        Returns 0 when ``fu_name`` has no entry in ``capacities``.
        """
        return self.capacities.get(fu_name, 0)

    def fu_index(self, fu_name: str) -> int:
        """Get index of a functional unit by name.

        Raises:
            ValueError: If ``fu_name`` is not in ``functional_units``.
        """
        return self.functional_units.index(fu_name)

    @property
    def num_functional_units(self) -> int:
        """Number of functional unit types on this machine."""
        return len(self.functional_units)

    @property
    def capacity_vector(self) -> np.ndarray:
        """Array of capacities indexed by functional unit index.

        Each entry is looked up through ``capacity``, so a unit listed in
        ``functional_units`` without an entry in ``capacities`` contributes 0
        here, exactly as it does for a by-name lookup.
        """
        return np.array([self.capacity(fu) for fu in self.functional_units])
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class DependenceGraph:
    """Loop dependence graph G = (V, E) over a given machine description.

    Instructions are the vertices and dependence edges connect them. This is
    the primary input to Twill's optimization pipeline.

    Usage:
        graph = DependenceGraph(machine)
        graph.add_instruction(Instruction("S", rrt_s))
        graph.add_instruction(Instruction("P", rrt_p))
        graph.add_edge(DependenceEdge("S", "P", delay=1))
        ...
    """

    def __init__(self, machine: MachineDescription):
        """Create an empty graph targeting ``machine``."""
        self.machine = machine
        self.instructions: Dict[str, Instruction] = {}
        self.edges: List[DependenceEdge] = []
        # Names in the order they were added; V is rebuilt from this list.
        self._instruction_order: List[str] = []

    def add_instruction(self, instr: Instruction):
        """Register ``instr`` as a vertex.

        Asserts that the name is not already present and that the RRT has one
        column per functional unit of the machine.
        """
        name = instr.name
        assert name not in self.instructions, f"Duplicate instruction: {instr.name}"
        width = instr.rrt.shape[1]
        assert width == self.machine.num_functional_units, \
            f"RRT width {instr.rrt.shape[1]} != num_functional_units {self.machine.num_functional_units}"
        self.instructions[name] = instr
        self._instruction_order.append(name)

    def add_edge(self, edge: DependenceEdge):
        """Append ``edge``; asserts both endpoints were added beforehand."""
        known = self.instructions
        assert edge.src in known, f"Unknown source: {edge.src}"
        assert edge.dst in known, f"Unknown destination: {edge.dst}"
        self.edges.append(edge)

    @property
    def V(self) -> List[Instruction]:
        """Instructions as a new list, in the order they were added."""
        ordered = []
        for name in self._instruction_order:
            ordered.append(self.instructions[name])
        return ordered

    @property
    def E(self) -> List[DependenceEdge]:
        """The graph's edge list (the same list object held in ``edges``)."""
        return self.edges

    @property
    def num_instructions(self) -> int:
        """Number of instructions currently in the graph."""
        return len(self.instructions)

    def get_instruction(self, name: str) -> Instruction:
        """Look up an instruction by name; raises KeyError if absent."""
        return self.instructions[name]

    def outgoing_edges(self, name: str) -> List[DependenceEdge]:
        """Edges whose source is ``name``, in insertion order."""
        selected = []
        for edge in self.edges:
            if edge.src == name:
                selected.append(edge)
        return selected

    def incoming_edges(self, name: str) -> List[DependenceEdge]:
        """Edges whose destination is ``name``, in insertion order."""
        selected = []
        for edge in self.edges:
            if edge.dst == name:
                selected.append(edge)
        return selected

    def has_loop_carried_output(self, name: str) -> bool:
        """True when some edge leaving ``name`` has iteration delay δ > 0."""
        for edge in self.outgoing_edges(name):
            if edge.iteration_delay > 0:
                return True
        return False

    def get_cycle_counts(self) -> List[int]:
        """Sorted distinct values among instruction cycle counts and edge delays.

        Both sources feed the same set, so a value that occurs as a cycle
        count and as a delay is reported once (used for cost normalization).
        """
        instruction_cycles = {instr.cycles for instr in self.V}
        edge_delays = {edge.delay for edge in self.edges}
        return sorted(instruction_cycles | edge_delays)

    def compute_min_initiation_interval(self) -> int:
        """Compute a lower bound on the initiation interval I (at least 1).

        Two families of bounds are combined with ``max``:

        * Resource: for each functional unit f with positive capacity,
          ceil(total RRT usage of f over all instructions / capacity(f)).
          Units whose capacity is not positive are skipped.
        * Recurrence: for each edge with δ > 0, ceil(d / δ).

        NOTE(review): the recurrence term looks at every loop-carried edge in
        isolation and does not add up delays around dependence cycles -
        confirm this is the intended bound.
        """
        bounds = [1]
        capacities = self.machine.capacity_vector

        for unit in range(self.machine.num_functional_units):
            demand = sum(instr.rrt[:, unit].sum() for instr in self.V)
            supply = capacities[unit]
            if supply > 0:
                bounds.append(int(np.ceil(demand / supply)))

        for edge in self.edges:
            if edge.iteration_delay > 0:
                bounds.append(int(np.ceil(edge.delay / edge.iteration_delay)))

        return max(bounds)

    def __repr__(self):
        return (f"DependenceGraph(|V|={self.num_instructions}, |E|={len(self.edges)}, "
                f"machine={self.machine.name})")
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
# ============================================================
|
| 207 |
+
# Pre-defined machine descriptions
|
| 208 |
+
# ============================================================
|
| 209 |
+
|
| 210 |
+
def hopper_machine(
    tc_capacity: int = 1,
    exp_capacity: int = 1,
    tma_capacity: int = 1,
) -> MachineDescription:
    """Build the machine description used for NVIDIA Hopper (H100).

    Args:
        tc_capacity: Capacity assigned to the "TC" functional unit.
        exp_capacity: Capacity assigned to the "EXP" functional unit.
        tma_capacity: Capacity assigned to the "TMA" functional unit.

    Returns:
        A MachineDescription named "Hopper" with units TC, EXP and TMA (in
        that index order), 228 KiB of SMEM, 256 KiB of REGS, 4 warps, and
        warp 0 designated for variable-latency ops.
    """
    # Insertion order of this dict fixes the functional unit indices.
    unit_capacities = {"TC": tc_capacity, "EXP": exp_capacity, "TMA": tma_capacity}
    memory_bytes = {"SMEM": 228 * 1024, "REGS": 256 * 1024}
    return MachineDescription(
        name="Hopper",
        functional_units=list(unit_capacities),
        capacities=unit_capacities,
        memory_spaces=memory_bytes,
        num_warps=4,
        variable_latency_warp=0,
    )
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def blackwell_machine(
    tc_capacity: int = 1,
    exp_capacity: int = 1,
    tma_capacity: int = 1,
    tmem_capacity: int = 1,
) -> MachineDescription:
    """Build the machine description used for NVIDIA Blackwell (B200).

    Args:
        tc_capacity: Capacity assigned to the "TC" functional unit.
        exp_capacity: Capacity assigned to the "EXP" functional unit.
        tma_capacity: Capacity assigned to the "TMA" functional unit.
        tmem_capacity: Capacity assigned to the "TMEM" functional unit.

    Returns:
        A MachineDescription named "Blackwell" with units TC, EXP, TMA and
        TMEM (in that index order), memory spaces SMEM, REGS and TMEM,
        4 warps, and warp 0 designated for variable-latency ops.
    """
    # Insertion order of this dict fixes the functional unit indices.
    unit_capacities = {
        "TC": tc_capacity,
        "EXP": exp_capacity,
        "TMA": tma_capacity,
        "TMEM": tmem_capacity,
    }
    # NOTE(review): TMEM is sized at 128 KiB here - confirm against the
    # hardware documentation for the targeted part.
    memory_bytes = {"SMEM": 228 * 1024, "REGS": 256 * 1024, "TMEM": 128 * 1024}
    return MachineDescription(
        name="Blackwell",
        functional_units=list(unit_capacities),
        capacities=unit_capacities,
        memory_spaces=memory_bytes,
        num_warps=4,
        variable_latency_warp=0,
    )
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def make_rrt(cycles: int, fu_usage: Dict[int, List[int]], num_fus: int) -> np.ndarray:
    """Create an RRT array of shape (cycles, num_fus) with integer dtype.

    Args:
        cycles: Number of rows (clock cycles) in the table.
        fu_usage: Maps a functional unit index to its usage per cycle; entry
            k of the list is written to row k of that unit's column.
        num_fus: Number of columns (functional units) in the table.

    Returns:
        The filled table. Cells not mentioned in ``fu_usage`` stay 0, and
        usage entries at positions >= ``cycles`` are ignored.
    """
    table = np.zeros((cycles, num_fus), dtype=int)
    for column, per_cycle in fu_usage.items():
        # Pairing with range(cycles) drops any entries past the last row.
        for row, amount in zip(range(cycles), per_cycle):
            table[row, column] = amount
    return table
|