AshenNav commited on
Commit
6b0a12a
Β·
verified Β·
1 Parent(s): cb0545b

Upload twill/visualization.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. twill/visualization.py +253 -0
twill/visualization.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Visualization: Generate schedule diagrams and warp assignment views.
3
+
4
+ Based on the figures in the paper (Figures 1, 3, 7, 9).
5
+ """
6
+
7
+ import numpy as np
8
+ from typing import Dict, List, Optional
9
+ from twill.graph import DependenceGraph
10
+ from twill.smt_joint import JointSWPWSResult
11
+ from twill.modulo_scheduler import ModuloScheduleResult
12
+
13
+
14
+ # Color palette for different instruction types / warps
15
+ WARP_COLORS = [
16
+ '#4CAF50', # Green - variable latency / producer
17
+ '#E91E63', # Pink - compute / TC
18
+ '#2196F3', # Blue - compute / EXP
19
+ '#FF9800', # Orange - compute / misc
20
+ '#9C27B0', # Purple
21
+ '#00BCD4', # Cyan
22
+ '#795548', # Brown
23
+ '#607D8B', # Blue-grey
24
+ ]
25
+
26
+ FU_COLORS = {
27
+ 'TC': '#E91E63', # Pink for Tensor Core
28
+ 'EXP': '#2196F3', # Blue for Exponential
29
+ 'TMA': '#4CAF50', # Green for TMA loads
30
+ 'TMEM': '#FF9800', # Orange for Tensor Memory
31
+ }
32
+
33
+
34
+ def visualize_schedule(
35
+ graph: DependenceGraph,
36
+ result: JointSWPWSResult,
37
+ output_path: Optional[str] = None,
38
+ title: str = "Twill Schedule",
39
+ ) -> str:
40
+ """Generate a text-based visualization of the schedule.
41
+
42
+ Shows a timeline with instructions placed at their scheduled cycles,
43
+ colored by warp assignment (in the text representation, shown as markers).
44
+
45
+ Args:
46
+ graph: The dependence graph
47
+ result: The joint SWP+WS result
48
+ output_path: If provided, also generate a matplotlib figure
49
+ title: Title for the visualization
50
+
51
+ Returns:
52
+ String representation of the schedule
53
+ """
54
+ I = result.I
55
+ L = result.length
56
+ n_copies = result.num_copies
57
+ M = result.schedule
58
+ wa = result.warp_assignment
59
+
60
+ lines = []
61
+ lines.append(f"β•”{'═' * 60}β•—")
62
+ lines.append(f"β•‘ {title:^58s} β•‘")
63
+ lines.append(f"β•‘ I={I}, L={L}, copies={n_copies}{' ' * (58 - len(f'I={I}, L={L}, copies={n_copies}'))}β•‘")
64
+ lines.append(f"β• {'═' * 60}β•£")
65
+
66
+ # Header: functional units
67
+ fu_names = graph.machine.functional_units
68
+ header = "Cycle β”‚"
69
+ for fu in fu_names:
70
+ header += f" {fu:^8s} β”‚"
71
+ header += " Warp "
72
+ lines.append(f"β•‘ {header:<58s} β•‘")
73
+ lines.append(f"β•‘ {'─' * (len(header)):^58s} β•‘")
74
+
75
+ # Build timeline
76
+ for t in range(L):
77
+ # Find what's scheduled at this cycle
78
+ active_ops = []
79
+ for v in graph.V:
80
+ for i in range(n_copies):
81
+ abs_time = M[v.name] + i * I
82
+ if abs_time == t:
83
+ active_ops.append((v, i))
84
+
85
+ if active_ops:
86
+ for v, i in active_ops:
87
+ warp = wa.warp_of(v.name)
88
+ # Build functional unit usage string
89
+ fu_str = f" {t:3d} β”‚"
90
+ for f_idx in range(len(fu_names)):
91
+ usage = int(v.rrt[:, f_idx].sum())
92
+ if usage > 0:
93
+ fu_str += f" {v.name:^8s} β”‚"
94
+ else:
95
+ fu_str += f" {'Β·':^8s} β”‚"
96
+ fu_str += f" W{warp} "
97
+ if i > 0:
98
+ fu_str += f"(i+{i})"
99
+ lines.append(f"β•‘ {fu_str:<58s} β•‘")
100
+ else:
101
+ fu_str = f" {t:3d} β”‚"
102
+ for _ in fu_names:
103
+ fu_str += f" {'Β·':^8s} β”‚"
104
+ fu_str += " "
105
+ lines.append(f"β•‘ {fu_str:<58s} β•‘")
106
+
107
+ lines.append(f"β• {'═' * 60}β•£")
108
+
109
+ # Warp assignment summary
110
+ lines.append(f"β•‘ {'Warp Assignments:':^58s} β•‘")
111
+ for w in range(graph.machine.num_warps):
112
+ instrs = wa.instructions_on_warp(w)
113
+ if instrs:
114
+ label = wa.warp_names.get(w, f"Warp {w}")
115
+ instr_str = f" {label}: {', '.join(instrs)}"
116
+ lines.append(f"β•‘ {instr_str:<58s} β•‘")
117
+
118
+ # Cross-warp barriers
119
+ barriers = []
120
+ for edge in graph.E:
121
+ src_warp = wa.warp_of(edge.src)
122
+ dst_warp = wa.warp_of(edge.dst)
123
+ if src_warp != dst_warp:
124
+ barriers.append(f" {edge.src}(W{src_warp}) β†’ {edge.dst}(W{dst_warp})")
125
+
126
+ if barriers:
127
+ lines.append(f"β•‘ {'':^58s} β•‘")
128
+ lines.append(f"β•‘ {'Cross-Warp Barriers:':^58s} β•‘")
129
+ for b in barriers:
130
+ lines.append(f"β•‘ {b:<58s} β•‘")
131
+
132
+ lines.append(f"β•š{'═' * 60}╝")
133
+
134
+ text_viz = "\n".join(lines)
135
+
136
+ # Optionally generate matplotlib figure
137
+ if output_path:
138
+ _generate_matplotlib_figure(graph, result, output_path, title)
139
+
140
+ return text_viz
141
+
142
+
143
+ def _generate_matplotlib_figure(
144
+ graph: DependenceGraph,
145
+ result: JointSWPWSResult,
146
+ output_path: str,
147
+ title: str,
148
+ ):
149
+ """Generate a matplotlib figure of the schedule (Gantt chart style)."""
150
+ try:
151
+ import matplotlib.pyplot as plt
152
+ import matplotlib.patches as mpatches
153
+ except ImportError:
154
+ print("matplotlib not available for figure generation")
155
+ return
156
+
157
+ I = result.I
158
+ L = result.length
159
+ n_copies = result.num_copies
160
+ M = result.schedule
161
+ wa = result.warp_assignment
162
+ machine = graph.machine
163
+
164
+ fig, ax = plt.subplots(1, 1, figsize=(14, max(6, L * 0.4)))
165
+
166
+ # Y-axis: time (cycles), X-axis: functional units
167
+ fu_names = machine.functional_units
168
+ n_fus = len(fu_names)
169
+ bar_width = 0.8
170
+
171
+ for v in graph.V:
172
+ warp = wa.warp_of(v.name)
173
+ color = WARP_COLORS[warp % len(WARP_COLORS)]
174
+
175
+ for i in range(n_copies):
176
+ abs_time = M[v.name] + i * I
177
+ if abs_time >= L:
178
+ continue
179
+
180
+ for c in range(v.cycles):
181
+ for f_idx in range(n_fus):
182
+ if v.rrt[c, f_idx] > 0:
183
+ rect = mpatches.FancyBboxPatch(
184
+ (f_idx - bar_width / 2, abs_time + c),
185
+ bar_width, 1,
186
+ boxstyle="round,pad=0.05",
187
+ facecolor=color,
188
+ edgecolor='black',
189
+ linewidth=0.5,
190
+ alpha=0.8,
191
+ )
192
+ ax.add_patch(rect)
193
+ label = f"{v.name}" if i == 0 else f"{v.name}+{i}"
194
+ ax.text(f_idx, abs_time + c + 0.5, label,
195
+ ha='center', va='center', fontsize=8,
196
+ fontweight='bold', color='white')
197
+
198
+ # Formatting
199
+ ax.set_xlim(-0.5, n_fus - 0.5)
200
+ ax.set_ylim(-0.5, L + 0.5)
201
+ ax.set_xticks(range(n_fus))
202
+ ax.set_xticklabels(fu_names, fontsize=10)
203
+ ax.set_yticks(range(L))
204
+ ax.set_ylabel("Clock Cycle", fontsize=12)
205
+ ax.set_xlabel("Functional Unit", fontsize=12)
206
+ ax.set_title(f"{title}\nI={I}, L={L}, copies={n_copies}", fontsize=14)
207
+ ax.invert_yaxis()
208
+ ax.grid(True, alpha=0.3)
209
+
210
+ # Legend for warps
211
+ legend_patches = []
212
+ for w in range(machine.num_warps):
213
+ instrs = wa.instructions_on_warp(w)
214
+ if instrs:
215
+ label = wa.warp_names.get(w, f"Warp {w}")
216
+ legend_patches.append(
217
+ mpatches.Patch(color=WARP_COLORS[w % len(WARP_COLORS)],
218
+ label=f"{label}: {', '.join(instrs)}")
219
+ )
220
+ ax.legend(handles=legend_patches, loc='upper right', fontsize=8)
221
+
222
+ plt.tight_layout()
223
+ plt.savefig(output_path, dpi=150, bbox_inches='tight')
224
+ plt.close()
225
+ print(f"Schedule figure saved to {output_path}")
226
+
227
+
228
+ def print_modular_rrt(
229
+ graph: DependenceGraph,
230
+ schedule: ModuloScheduleResult,
231
+ ) -> str:
232
+ """Print the modular RRT as a table."""
233
+ from twill.modulo_scheduler import compute_modular_rrt
234
+
235
+ mod_rrt = compute_modular_rrt(graph, schedule)
236
+ I = schedule.I
237
+ fu_names = graph.machine.functional_units
238
+
239
+ lines = [f"Modular RRT (I={I}):"]
240
+ header = " t β”‚ " + " β”‚ ".join(f"{fu:^8s}" for fu in fu_names) + " β”‚"
241
+ lines.append(header)
242
+ lines.append("─" * len(header))
243
+
244
+ for t in range(I):
245
+ row = f" {t:2d} β”‚ "
246
+ for f_idx in range(len(fu_names)):
247
+ val = mod_rrt[t, f_idx]
248
+ cap = graph.machine.capacity_vector[f_idx]
249
+ marker = "!" if val > cap else " "
250
+ row += f" {val:^6d}{marker} β”‚ "
251
+ lines.append(row)
252
+
253
+ return "\n".join(lines)