File size: 9,378 Bytes
a3e198c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
"""
Core data structures: Dependence Graph, Instructions, RRTs, Machine Description.

Based on Section 3.1 of the paper:
- G = (V, E) where V = instructions, E = dependence edges
- Each instruction v has an RRT (Resource Reservation Table)
- Each edge (u, v, d, δ) has clock delay d and iteration delay δ
- Machine description D defines functional unit capacities and memory capacities
"""

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Set
import numpy as np


@dataclass
class Instruction:
    """A tile-level instruction in the loop body.

    Attributes:
        name: Unique identifier for this instruction.
        rrt: Resource Reservation Table - 2D array [cycle, functional_unit] -> usage count.
             Each row = a clock cycle of execution, each column = a functional unit type.
        variable_latency: Whether this instruction has variable latency (e.g., TMA loads).
        memory_footprint: Dict mapping memory_space -> bytes used by this instruction's output.
        streaming: Whether this is a streaming variable-latency op (no incoming data deps).
    """
    name: str
    rrt: np.ndarray  # shape: (cycles, num_functional_units)
    variable_latency: bool = False
    memory_footprint: Dict[str, int] = field(default_factory=dict)
    streaming: bool = False

    @property
    def cycles(self) -> int:
        """Number of clock cycles this instruction takes (row count of the RRT)."""
        return self.rrt.shape[0]

    @property
    def functional_units_used(self) -> Set[int]:
        """Set of functional unit indices (RRT columns) with non-zero total usage.

        The indices are returned as plain Python ints, matching the annotation,
        rather than as numpy integer scalars.
        """
        column_totals = self.rrt.sum(axis=0)
        return set(np.where(column_totals > 0)[0].tolist())

    def __eq__(self, other):
        """Field-wise equality that compares the RRT arrays element-wise.

        The dataclass-generated ``__eq__`` compares fields as a tuple, which
        calls ``bool()`` on ``rrt == other.rrt`` and raises ``ValueError`` for
        arrays with more than one element. Defining it here keeps ``@dataclass``
        from generating that version.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        return (
            self.name == other.name
            and self.variable_latency == other.variable_latency
            and self.memory_footprint == other.memory_footprint
            and self.streaming == other.streaming
            and np.array_equal(self.rrt, other.rrt)
        )

    def __repr__(self):
        return f"Instruction({self.name}, cycles={self.cycles}, var_lat={self.variable_latency})"


@dataclass
class DependenceEdge:
    """One data dependence (u, v, d, δ) of the loop dependence graph.

    Following Section 3.1, the edge states that v must be issued at least
    d cycles after u, where v belongs to iteration i and u belongs to
    iteration i - δ.

    Attributes:
        src: Name of the source instruction u.
        dst: Name of the destination instruction v.
        delay: d, the minimum number of clock cycles between the issue of
            src and the issue of dst.
        iteration_delay: δ, the iteration distance (0 = same iteration,
            1 = loop-carried). Defaults to 0.
    """
    src: str
    dst: str
    delay: int  # d
    iteration_delay: int = 0  # δ

    def __repr__(self):
        endpoints = f"{self.src} -> {self.dst}"
        return f"Edge({endpoints}, d={self.delay}, δ={self.iteration_delay})"


@dataclass
class MachineDescription:
    """Description of the target GPU architecture.

    Attributes:
        name: Architecture name (e.g., "Hopper", "Blackwell").
        functional_units: List of functional unit names (e.g., ["TC", "EXP", "TMA"]).
            The position of a name in this list is its functional unit index.
        capacities: Dict mapping functional_unit_name -> capacity (max simultaneous usage).
        memory_spaces: Dict mapping memory_space_name -> capacity in bytes.
        num_warps: Number of available warps for WS.
        variable_latency_warp: Index of the warp designated for variable-latency ops.
    """
    name: str
    functional_units: List[str]
    capacities: Dict[str, int]
    memory_spaces: Dict[str, int] = field(default_factory=dict)
    num_warps: int = 4
    variable_latency_warp: int = 0  # W_vl

    def capacity(self, fu_name: str) -> int:
        """Get capacity of a functional unit by name (0 if it has no entry)."""
        return self.capacities.get(fu_name, 0)

    def fu_index(self, fu_name: str) -> int:
        """Get index of a functional unit by name.

        Raises:
            ValueError: If ``fu_name`` is not in ``functional_units``.
        """
        return self.functional_units.index(fu_name)

    @property
    def num_functional_units(self) -> int:
        """Number of functional unit types on this machine."""
        return len(self.functional_units)

    @property
    def capacity_vector(self) -> np.ndarray:
        """Array of capacities indexed by functional unit index.

        Goes through ``capacity()`` so that a unit listed in
        ``functional_units`` but absent from ``capacities`` yields 0 here too,
        instead of raising ``KeyError``.
        """
        return np.array([self.capacity(fu) for fu in self.functional_units])


class DependenceGraph:
    """Loop dependence graph G = (V, E).

    This is the primary input to Twill's optimization pipeline.

    Instructions are stored by name and returned in insertion order; edges
    refer to instructions by name and are kept in the order they were added.

    Usage:
        graph = DependenceGraph(machine)
        graph.add_instruction(Instruction("S", rrt_s))
        graph.add_instruction(Instruction("P", rrt_p))
        graph.add_edge(DependenceEdge("S", "P", delay=1))
        ...
    """

    def __init__(self, machine: MachineDescription):
        self.machine = machine
        self.instructions: Dict[str, Instruction] = {}
        self.edges: List[DependenceEdge] = []
        self._instruction_order: List[str] = []  # maintain insertion order

    def add_instruction(self, instr: Instruction):
        """Add an instruction to the graph.

        Raises:
            AssertionError: If an instruction with the same name already
                exists, or if the RRT width does not match the machine's
                number of functional units.
        """
        # Raised explicitly rather than via `assert` so the checks still run
        # under `python -O`, which strips assert statements.
        if instr.name in self.instructions:
            raise AssertionError(f"Duplicate instruction: {instr.name}")
        if instr.rrt.shape[1] != self.machine.num_functional_units:
            raise AssertionError(
                f"RRT width {instr.rrt.shape[1]} != num_functional_units {self.machine.num_functional_units}"
            )
        self.instructions[instr.name] = instr
        self._instruction_order.append(instr.name)

    def add_edge(self, edge: DependenceEdge):
        """Add a dependence edge to the graph.

        Raises:
            AssertionError: If either endpoint has not been added as an
                instruction yet.
        """
        # Explicit raises for the same reason as in add_instruction.
        if edge.src not in self.instructions:
            raise AssertionError(f"Unknown source: {edge.src}")
        if edge.dst not in self.instructions:
            raise AssertionError(f"Unknown destination: {edge.dst}")
        self.edges.append(edge)

    @property
    def V(self) -> List[Instruction]:
        """List of instructions in insertion order."""
        return [self.instructions[name] for name in self._instruction_order]

    @property
    def E(self) -> List[DependenceEdge]:
        """List of dependence edges (the graph's own list, not a copy)."""
        return self.edges

    @property
    def num_instructions(self) -> int:
        """Number of instructions in the graph."""
        return len(self.instructions)

    def get_instruction(self, name: str) -> Instruction:
        """Look up an instruction by name; raises KeyError if it is unknown."""
        return self.instructions[name]

    def outgoing_edges(self, name: str) -> List[DependenceEdge]:
        """Get all edges where name is the source."""
        return [e for e in self.edges if e.src == name]

    def incoming_edges(self, name: str) -> List[DependenceEdge]:
        """Get all edges where name is the destination."""
        return [e for e in self.edges if e.dst == name]

    def has_loop_carried_output(self, name: str) -> bool:
        """Check if instruction has any outgoing loop-carried edge (δ > 0)."""
        return any(e.iteration_delay > 0 for e in self.outgoing_edges(name))

    def get_cycle_counts(self) -> List[int]:
        """Get the sorted distinct cycle values in the graph (for cost normalization).

        The result combines every instruction's cycle count with every edge's
        delay, de-duplicated and in ascending order.
        """
        delays = {instr.cycles for instr in self.V}
        delays.update(edge.delay for edge in self.edges)
        return sorted(delays)

    def compute_min_initiation_interval(self) -> int:
        """Compute a lower bound on the initiation interval I (at least 1).

        Resource bound, for each functional unit f with capacity(f) > 0:
            I >= ceil(sum of RRT usage across all instructions / capacity(f))
        Functional units whose capacity is 0 are skipped.

        Recurrence bound, for each edge with δ > 0:
            I >= ceil(d / δ)
        """
        min_I = 1
        cap_vec = self.machine.capacity_vector
        for fu_idx in range(self.machine.num_functional_units):
            total_usage = sum(instr.rrt[:, fu_idx].sum() for instr in self.V)
            if cap_vec[fu_idx] > 0:
                resource_bound = int(np.ceil(total_usage / cap_vec[fu_idx]))
                min_I = max(min_I, resource_bound)

        # Recurrence-constrained lower bound
        # NOTE(review): the bound is applied to each loop-carried edge on its
        # own. The classical recurrence bound is taken over dependence cycles
        # (sum of delays / sum of iteration distances), so for a loop-carried
        # edge that is not a self-loop this may exceed the true bound.
        # Confirm against the paper before relying on it as a strict minimum.
        for edge in self.edges:
            if edge.iteration_delay > 0:
                rec_bound = int(np.ceil(edge.delay / edge.iteration_delay))
                min_I = max(min_I, rec_bound)

        return min_I

    def __repr__(self):
        return (f"DependenceGraph(|V|={self.num_instructions}, |E|={len(self.edges)}, "
                f"machine={self.machine.name})")


# ============================================================
# Pre-defined machine descriptions
# ============================================================

def hopper_machine(
    tc_capacity: int = 1,
    exp_capacity: int = 1,
    tma_capacity: int = 1,
) -> MachineDescription:
    """Build the machine description for NVIDIA Hopper (H100).

    Args:
        tc_capacity: Capacity assigned to the "TC" functional unit.
        exp_capacity: Capacity assigned to the "EXP" functional unit.
        tma_capacity: Capacity assigned to the "TMA" functional unit.

    Returns:
        A MachineDescription named "Hopper" with functional units
        TC, EXP and TMA (in that index order), memory spaces SMEM and REGS,
        4 warps, and warp 0 designated for variable-latency ops.
    """
    kib = 1024
    # Insertion order of this dict fixes the functional unit index order.
    unit_capacities = {
        "TC": tc_capacity,
        "EXP": exp_capacity,
        "TMA": tma_capacity,
    }
    memory_bytes = {"SMEM": 228 * kib, "REGS": 256 * kib}
    return MachineDescription(
        name="Hopper",
        functional_units=list(unit_capacities),
        capacities=unit_capacities,
        memory_spaces=memory_bytes,
        num_warps=4,
        variable_latency_warp=0,
    )


def blackwell_machine(
    tc_capacity: int = 1,
    exp_capacity: int = 1,
    tma_capacity: int = 1,
    tmem_capacity: int = 1,
) -> MachineDescription:
    """Build the machine description for NVIDIA Blackwell (B200).

    Args:
        tc_capacity: Capacity assigned to the "TC" functional unit.
        exp_capacity: Capacity assigned to the "EXP" functional unit.
        tma_capacity: Capacity assigned to the "TMA" functional unit.
        tmem_capacity: Capacity assigned to the "TMEM" functional unit.

    Returns:
        A MachineDescription named "Blackwell" with functional units
        TC, EXP, TMA and TMEM (in that index order), memory spaces SMEM,
        REGS and TMEM, 4 warps, and warp 0 designated for variable-latency ops.
    """
    kib = 1024
    # Insertion order of this dict fixes the functional unit index order.
    unit_capacities = {
        "TC": tc_capacity,
        "EXP": exp_capacity,
        "TMA": tma_capacity,
        "TMEM": tmem_capacity,
    }
    # NOTE(review): NVIDIA's PTX documentation appears to describe tensor
    # memory as 128 lanes x 512 32-bit columns per CTA (256 KiB); confirm that
    # the 128 KiB TMEM figure below is intentional.
    memory_bytes = {"SMEM": 228 * kib, "REGS": 256 * kib, "TMEM": 128 * kib}
    return MachineDescription(
        name="Blackwell",
        functional_units=list(unit_capacities),
        capacities=unit_capacities,
        memory_spaces=memory_bytes,
        num_warps=4,
        variable_latency_warp=0,
    )


def make_rrt(cycles: int, fu_usage: Dict[int, List[int]], num_fus: int) -> np.ndarray:
    """Build a Resource Reservation Table from per-unit usage lists.

    Args:
        cycles: Number of rows (clock cycles) in the table.
        fu_usage: Maps a functional unit index (column) to its usage for
            cycle 0, 1, 2, ... Entries at positions >= ``cycles`` are ignored;
            cycles not covered by the list stay at 0.
        num_fus: Number of columns (functional units) in the table.

    Returns:
        An integer array of shape (cycles, num_fus).
    """
    table = np.zeros((cycles, num_fus), dtype=int)
    for column, per_cycle in fu_usage.items():
        # Pairing with range(cycles) drops any usage entries past the last row.
        for row, count in zip(range(cycles), per_cycle):
            table[row, column] = count
    return table