"""
Visualization: Generate schedule diagrams and warp assignment views.
Based on the figures in the paper (Figures 1, 3, 7, 9).
"""
import numpy as np
from typing import Dict, List, Optional
from twill.graph import DependenceGraph
from twill.smt_joint import JointSWPWSResult
from twill.modulo_scheduler import ModuloScheduleResult
# Fill colors used when drawing schedules.
# WARP_COLORS is indexed by warp id (wrapped modulo its length by the
# plotting code); FU_COLORS maps a functional-unit name to a color.
WARP_COLORS = [
    "#4CAF50",  # green: variable latency / producer
    "#E91E63",  # pink: compute / TC
    "#2196F3",  # blue: compute / EXP
    "#FF9800",  # orange: compute / misc
    "#9C27B0",  # purple
    "#00BCD4",  # cyan
    "#795548",  # brown
    "#607D8B",  # blue-grey
]

FU_COLORS = dict(
    TC="#E91E63",    # pink: Tensor Core
    EXP="#2196F3",   # blue: exponential unit
    TMA="#4CAF50",   # green: TMA loads
    TMEM="#FF9800",  # orange: Tensor Memory
)
def visualize_schedule(
    graph: DependenceGraph,
    result: JointSWPWSResult,
    output_path: Optional[str] = None,
    title: str = "Twill Schedule",
) -> str:
    """Generate a text-based visualization of the schedule.

    The result is a box-drawn table with three sections:

    * a title block with the initiation interval ``I``, the schedule length
      ``L`` and the number of overlapped copies;
    * a timeline with one row per (instruction, copy) that issues at cycle
      ``t`` for ``t`` in ``[0, L)``, or an empty row when nothing issues.
      An instruction's name appears in every functional-unit column that its
      reservation table ``v.rrt`` uses at any of its cycles, followed by its
      warp marker ``W<n>`` and, for copies after the first, ``(i+k)``;
    * the instruction list of each non-empty warp, plus every dependence
      edge whose endpoints are on different warps (cross-warp barriers).

    The frame is at least 58 characters wide inside and widens to fit the
    longest row, so the right border always lines up.

    Args:
        graph: The dependence graph (its ``V``, ``E`` and ``machine`` are read).
        result: The joint SWP+WS result.
        output_path: If provided, also generate a matplotlib figure there.
        title: Title for the visualization.

    Returns:
        String representation of the schedule.
    """
    I = result.I
    L = result.length
    n_copies = result.num_copies
    M = result.schedule
    wa = result.warp_assignment
    fu_names = graph.machine.functional_units
    n_fus = len(fu_names)

    # Each section is a list of (text, align) rows; align is a format-spec
    # alignment character ('^' centered, '<' left-aligned).
    title_section = [
        (title, "^"),
        (f"I={I}, L={L}, copies={n_copies}", "<"),
    ]

    # Header cells and row cells have identical widths ("Cycle │" is 7 chars,
    # as is f"{t:5d} │") so the "│" separators line up vertically.
    header = "Cycle │" + "".join(f" {fu:^8s} │" for fu in fu_names) + " Warp "
    timeline_section = [(header, "<"), ("─" * len(header), "<")]

    # Bucket every (instruction, copy) by its absolute issue cycle once,
    # instead of rescanning the whole graph for every cycle. Iterating V in
    # the outer loop keeps the same per-cycle row order as a rescan would.
    ops_by_cycle: Dict[int, list] = {}
    for v in graph.V:
        for i in range(n_copies):
            ops_by_cycle.setdefault(M[v.name] + i * I, []).append((v, i))

    empty_cell = f" {'·':^8s} │"
    for t in range(L):
        active_ops = ops_by_cycle.get(t)
        if not active_ops:
            timeline_section.append((f"{t:5d} │" + empty_cell * n_fus, "<"))
            continue
        for v, i in active_ops:
            warp = wa.warp_of(v.name)
            # Mark every FU the instruction occupies at any of its cycles.
            cells = "".join(
                f" {v.name:^8s} │" if int(v.rrt[:, f_idx].sum()) > 0 else empty_cell
                for f_idx in range(n_fus)
            )
            copy_tag = f"(i+{i})" if i > 0 else ""
            timeline_section.append((f"{t:5d} │{cells} W{warp} {copy_tag}", "<"))

    warp_section = [("Warp Assignments:", "^")]
    for w in range(graph.machine.num_warps):
        instrs = wa.instructions_on_warp(w)
        if instrs:
            label = wa.warp_names.get(w, f"Warp {w}")
            warp_section.append((f" {label}: {', '.join(instrs)}", "<"))

    # Edges crossing warps need a synchronization point between the warps.
    barriers = []
    for edge in graph.E:
        src_warp = wa.warp_of(edge.src)
        dst_warp = wa.warp_of(edge.dst)
        if src_warp != dst_warp:
            barriers.append(f" {edge.src}(W{src_warp}) → {edge.dst}(W{dst_warp})")
    if barriers:
        warp_section.append(("", "^"))
        warp_section.append(("Cross-Warp Barriers:", "^"))
        warp_section.extend((b, "<") for b in barriers)

    sections = [title_section, timeline_section, warp_section]
    # Widen the frame when any row exceeds the default width, so long
    # instruction names or many functional units don't break the border.
    inner = max([58] + [len(text) for section in sections for text, _ in section])
    rule = "═" * (inner + 2)
    lines = [f"╔{rule}╗"]
    for idx, section in enumerate(sections):
        if idx:
            lines.append(f"╠{rule}╣")
        lines.extend(f"║ {text:{align}{inner}s} ║" for text, align in section)
    lines.append(f"╚{rule}╝")
    text_viz = "\n".join(lines)

    # Optionally generate matplotlib figure
    if output_path:
        _generate_matplotlib_figure(graph, result, output_path, title)
    return text_viz
def _generate_matplotlib_figure(
    graph: DependenceGraph,
    result: JointSWPWSResult,
    output_path: str,
    title: str,
):
    """Render the schedule as a Gantt-style chart and save it to *output_path*.

    Functional units run along the X axis and clock cycles along the
    (inverted) Y axis. Each cycle ``c`` of an instruction copy that issues
    before ``result.length`` is drawn as a one-cycle box in every
    functional-unit column where ``v.rrt[c, f] > 0``. Boxes are colored by
    warp via ``WARP_COLORS`` (wrapped modulo the palette size), and a legend
    lists each non-empty warp's instructions.

    This is best-effort: if matplotlib cannot be imported, a notice is
    printed and nothing is written. The pyplot figure is always closed, even
    if drawing or saving fails.
    """
    try:
        import matplotlib.pyplot as plt
        import matplotlib.patches as mpatches
    except ImportError:
        print("matplotlib not available for figure generation")
        return

    I = result.I
    L = result.length
    n_copies = result.num_copies
    M = result.schedule
    wa = result.warp_assignment
    machine = graph.machine
    fu_names = machine.functional_units
    n_fus = len(fu_names)
    bar_width = 0.8

    fig, ax = plt.subplots(1, 1, figsize=(14, max(6, L * 0.4)))
    try:
        # Y-axis: time (cycles), X-axis: functional units
        for v in graph.V:
            warp = wa.warp_of(v.name)
            color = WARP_COLORS[warp % len(WARP_COLORS)]
            for i in range(n_copies):
                abs_time = M[v.name] + i * I
                if abs_time >= L:
                    continue  # this copy issues outside the plotted window
                label = f"{v.name}" if i == 0 else f"{v.name}+{i}"
                for c in range(v.cycles):
                    for f_idx in range(n_fus):
                        if v.rrt[c, f_idx] <= 0:
                            continue
                        ax.add_patch(mpatches.FancyBboxPatch(
                            (f_idx - bar_width / 2, abs_time + c),
                            bar_width, 1,
                            boxstyle="round,pad=0.05",
                            facecolor=color,
                            edgecolor='black',
                            linewidth=0.5,
                            alpha=0.8,
                        ))
                        ax.text(f_idx, abs_time + c + 0.5, label,
                                ha='center', va='center', fontsize=8,
                                fontweight='bold', color='white')

        # Formatting
        ax.set_xlim(-0.5, n_fus - 0.5)
        ax.set_ylim(-0.5, L + 0.5)
        ax.set_xticks(range(n_fus))
        ax.set_xticklabels(fu_names, fontsize=10)
        ax.set_yticks(range(L))
        ax.set_ylabel("Clock Cycle", fontsize=12)
        ax.set_xlabel("Functional Unit", fontsize=12)
        ax.set_title(f"{title}\nI={I}, L={L}, copies={n_copies}", fontsize=14)
        ax.invert_yaxis()  # time flows downward
        ax.grid(True, alpha=0.3)

        # Legend for warps (only warps that actually hold instructions)
        legend_patches = []
        for w in range(machine.num_warps):
            instrs = wa.instructions_on_warp(w)
            if instrs:
                warp_label = wa.warp_names.get(w, f"Warp {w}")
                legend_patches.append(
                    mpatches.Patch(color=WARP_COLORS[w % len(WARP_COLORS)],
                                   label=f"{warp_label}: {', '.join(instrs)}")
                )
        if legend_patches:
            # Skip the legend entirely rather than drawing an empty box.
            ax.legend(handles=legend_patches, loc='upper right', fontsize=8)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Release this specific figure even on failure so repeated calls do
        # not accumulate open pyplot figures.
        plt.close(fig)
    print(f"Schedule figure saved to {output_path}")
def print_modular_rrt(
    graph: DependenceGraph,
    schedule: ModuloScheduleResult,
) -> str:
    """Format the modular reservation table of *schedule* as a text table.

    Despite the name, nothing is printed; the table is returned as a string.
    The table has one row per cycle ``t`` in ``[0, I)`` and one column per
    functional unit of ``graph.machine``. Each cell shows the usage from
    ``compute_modular_rrt``, followed by ``!`` when that usage exceeds the
    unit's entry in ``graph.machine.capacity_vector``.

    Args:
        graph: The dependence graph whose machine defines units and capacities.
        schedule: The modulo schedule to tabulate.

    Returns:
        The formatted table, lines joined with newlines.
    """
    from twill.modulo_scheduler import compute_modular_rrt
    mod_rrt = compute_modular_rrt(graph, schedule)
    I = schedule.I
    fu_names = graph.machine.functional_units
    capacity = graph.machine.capacity_vector

    # The row-label column widens for large I. Header and data cells have
    # identical widths so the "│" separators line up.
    t_width = max(2, len(str(I)))
    header = f" {'t':>{t_width}s} │" + "".join(f" {fu:^8s} │" for fu in fu_names)
    lines = [f"Modular RRT (I={I}):", header, "─" * len(header)]
    for t in range(I):
        row = f" {t:{t_width}d} │"
        for f_idx in range(len(fu_names)):
            val = mod_rrt[t, f_idx]
            marker = "!" if val > capacity[f_idx] else " "
            # int() so a float-typed table still renders with the 'd' format.
            row += f" {int(val):^7d}{marker} │"
        lines.append(row)
    return "\n".join(lines)
|