# Source: twill-swp-ws / twill/visualization.py (uploaded by AshenNav)
# Commit 6b0a12a (verified): "Upload twill/visualization.py with huggingface_hub"
"""
Visualization: Generate schedule diagrams and warp assignment views.
Based on the figures in the paper (Figures 1, 3, 7, 9).
"""
import numpy as np
from typing import Dict, List, Optional
from twill.graph import DependenceGraph
from twill.smt_joint import JointSWPWSResult
from twill.modulo_scheduler import ModuloScheduleResult
# Color palette for different instruction types / warps.
# This module picks a warp's color as WARP_COLORS[warp % len(WARP_COLORS)],
# so warp ids beyond the end of the palette wrap around to the start.
WARP_COLORS = [
    '#4CAF50',  # green: variable latency / producer
    '#E91E63',  # pink: compute / TC
    '#2196F3',  # blue: compute / EXP
    '#FF9800',  # orange: compute / misc
    '#9C27B0',  # purple
    '#00BCD4',  # cyan
    '#795548',  # brown
    '#607D8B',  # blue-grey
]

# One fixed color per functional-unit name.
FU_COLORS = dict(
    TC='#E91E63',    # pink: Tensor Core
    EXP='#2196F3',   # blue: exponential unit
    TMA='#4CAF50',   # green: TMA loads
    TMEM='#FF9800',  # orange: Tensor Memory
)
def visualize_schedule(
    graph: DependenceGraph,
    result: JointSWPWSResult,
    output_path: Optional[str] = None,
    title: str = "Twill Schedule",
) -> str:
    """Generate a text-based visualization of the schedule.

    Draws a box whose content area is 58 characters wide. Rows longer than
    that are not truncated and push past the right border. The box holds:

    * a header with the title and ``I`` / ``L`` / copy count;
    * a cycle-by-cycle timeline for cycles ``0 .. result.length - 1``. Each
      (instruction, copy) pair is listed on the row of the cycle it starts in
      (``schedule[name] + copy * I``). Its name appears in the column of every
      functional unit its ``rrt`` uses in any cycle. The row ends with its
      warp and, for the k-th later copy, a ``(i+k)`` suffix. Cycles in which
      nothing starts get an empty row;
    * the instructions assigned to each non-empty warp;
    * the dependence edges whose endpoints are on different warps.

    Args:
        graph: The dependence graph (provides ``V``, ``E`` and ``machine``).
        result: The joint SWP+WS result (``I``, ``length``, ``num_copies``,
            ``schedule`` and ``warp_assignment``).
        output_path: If provided, also generate a matplotlib figure there.
        title: Title for the visualization.

    Returns:
        String representation of the schedule.
    """
    I = result.I
    L = result.length
    n_copies = result.num_copies
    M = result.schedule
    wa = result.warp_assignment
    fu_names = graph.machine.functional_units

    params = f"I={I}, L={L}, copies={n_copies}"
    lines = [
        f"╔{'═' * 60}╗",
        f"║ {title:^58s} ║",
        # Pad to the same 58 columns as every other row so the border lines up.
        f"║ {params:<58s} ║",
        f"╠{'═' * 60}╣",
    ]

    # Header: functional units
    header = "Cycle │" + "".join(f" {fu:^8s} │" for fu in fu_names) + " Warp "
    lines.append(f"║ {header:<58s} ║")
    lines.append(f"║ {'─' * len(header):^58s} ║")

    # Index every (instruction, copy) by its absolute start cycle once, rather
    # than rescanning all instructions and copies for every cycle. Insertion
    # order (instruction order, then copy order) is the order rows are printed.
    starts: Dict[int, List[tuple]] = {}
    for v in graph.V:
        for i in range(n_copies):
            starts.setdefault(M[v.name] + i * I, []).append((v, i))

    empty_cell = f" {'·':^8s} │"
    for t in range(L):
        active_ops = starts.get(t)
        if not active_ops:
            row = f" {t:3d} │" + empty_cell * len(fu_names) + " "
            lines.append(f"║ {row:<58s} ║")
            continue
        for v, i in active_ops:
            warp = wa.warp_of(v.name)
            # Show the instruction under every unit it reserves in any cycle.
            cells = "".join(
                f" {v.name:^8s} │" if int(v.rrt[:, f_idx].sum()) > 0 else empty_cell
                for f_idx in range(len(fu_names))
            )
            row = f" {t:3d} │{cells} W{warp} "
            if i > 0:
                row += f"(i+{i})"
            lines.append(f"║ {row:<58s} ║")
    lines.append(f"╠{'═' * 60}╣")

    # Warp assignment summary
    lines.append(f"║ {'Warp Assignments:':^58s} ║")
    for w in range(graph.machine.num_warps):
        instrs = wa.instructions_on_warp(w)
        if instrs:
            label = wa.warp_names.get(w, f"Warp {w}")
            instr_str = f" {label}: {', '.join(instrs)}"
            lines.append(f"║ {instr_str:<58s} ║")

    # Cross-warp barriers: every edge whose endpoints live on different warps.
    barriers = []
    for edge in graph.E:
        src_warp = wa.warp_of(edge.src)
        dst_warp = wa.warp_of(edge.dst)
        if src_warp != dst_warp:
            barriers.append(f" {edge.src}(W{src_warp}) → {edge.dst}(W{dst_warp})")
    if barriers:
        lines.append(f"║ {'':^58s} ║")
        lines.append(f"║ {'Cross-Warp Barriers:':^58s} ║")
        lines.extend(f"║ {b:<58s} ║" for b in barriers)
    lines.append(f"╚{'═' * 60}╝")

    text_viz = "\n".join(lines)
    # Optionally generate matplotlib figure
    if output_path:
        _generate_matplotlib_figure(graph, result, output_path, title)
    return text_viz
def _generate_matplotlib_figure(
    graph: DependenceGraph,
    result: JointSWPWSResult,
    output_path: str,
    title: str,
):
    """Generate a matplotlib figure of the schedule (Gantt chart style).

    Functional units run along the x-axis and clock cycles run down the
    y-axis. Only (instruction, copy) pairs that start before ``result.length``
    are drawn. For each such copy, every cycle ``c`` in which the
    instruction's ``rrt`` reserves unit ``f`` becomes a box centred on
    ``(f, start + c)``. The box is colored by the instruction's warp and
    labelled ``name`` (or ``name+k`` for the k-th later copy). Cycles at or
    beyond ``result.length`` are clipped so nothing is drawn outside the
    plotted window. A legend lists the instructions on each non-empty warp.

    If matplotlib cannot be imported, a message is printed and nothing is
    written. The figure is always closed, even if drawing or saving raises.
    """
    try:
        import matplotlib.pyplot as plt
        import matplotlib.patches as mpatches
    except ImportError:
        print("matplotlib not available for figure generation")
        return
    I = result.I
    L = result.length
    n_copies = result.num_copies
    M = result.schedule
    wa = result.warp_assignment
    machine = graph.machine
    # Y-axis: time (cycles), X-axis: functional units
    fu_names = machine.functional_units
    n_fus = len(fu_names)
    bar_width = 0.8

    fig, ax = plt.subplots(1, 1, figsize=(14, max(6, L * 0.4)))
    try:
        for v in graph.V:
            color = WARP_COLORS[wa.warp_of(v.name) % len(WARP_COLORS)]
            for i in range(n_copies):
                start = M[v.name] + i * I
                label = f"{v.name}" if i == 0 else f"{v.name}+{i}"
                for c in range(v.cycles):
                    cycle = start + c
                    if cycle >= L:
                        # This and all later cycles of the copy fall outside
                        # the plotted window (also skips copies starting at/after L).
                        break
                    for f_idx in range(n_fus):
                        if v.rrt[c, f_idx] > 0:
                            # Centre the box on (f_idx, cycle) in both axes so the
                            # cycle tick labels line up with the boxes.
                            rect = mpatches.FancyBboxPatch(
                                (f_idx - bar_width / 2, cycle - 0.5),
                                bar_width, 1,
                                boxstyle="round,pad=0.05",
                                facecolor=color,
                                edgecolor='black',
                                linewidth=0.5,
                                alpha=0.8,
                            )
                            ax.add_patch(rect)
                            ax.text(f_idx, cycle, label,
                                    ha='center', va='center', fontsize=8,
                                    fontweight='bold', color='white')

        # Formatting; max(L, 1) avoids identical y-limits for an empty schedule.
        ax.set_xlim(-0.5, n_fus - 0.5)
        ax.set_ylim(-0.5, max(L, 1) - 0.5)
        ax.set_xticks(range(n_fus))
        ax.set_xticklabels(fu_names, fontsize=10)
        ax.set_yticks(range(L))
        ax.set_ylabel("Clock Cycle", fontsize=12)
        ax.set_xlabel("Functional Unit", fontsize=12)
        ax.set_title(f"{title}\nI={I}, L={L}, copies={n_copies}", fontsize=14)
        ax.invert_yaxis()
        ax.grid(True, alpha=0.3)

        # Legend for warps (skipped entirely when no warp has instructions).
        legend_patches = []
        for w in range(machine.num_warps):
            instrs = wa.instructions_on_warp(w)
            if instrs:
                warp_label = wa.warp_names.get(w, f"Warp {w}")
                legend_patches.append(
                    mpatches.Patch(color=WARP_COLORS[w % len(WARP_COLORS)],
                                   label=f"{warp_label}: {', '.join(instrs)}")
                )
        if legend_patches:
            ax.legend(handles=legend_patches, loc='upper right', fontsize=8)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Close this specific figure (not pyplot's "current" one) even on error,
        # so repeated calls don't accumulate open figures.
        plt.close(fig)
    print(f"Schedule figure saved to {output_path}")
def print_modular_rrt(
    graph: DependenceGraph,
    schedule: ModuloScheduleResult,
) -> str:
    """Render the modular RRT as a text table and return it.

    Despite the name, nothing is printed; the table is returned as a string.

    The table has one row per slot ``t`` in ``0 .. schedule.I - 1`` and one
    column per functional unit of ``graph.machine``. Each cell shows
    ``compute_modular_rrt(graph, schedule)[t, f]``. A ``!`` after the value
    marks a slot where it exceeds ``graph.machine.capacity_vector[f]``.
    """
    from twill.modulo_scheduler import compute_modular_rrt
    mod_rrt = compute_modular_rrt(graph, schedule)
    I = schedule.I
    machine = graph.machine
    fu_names = machine.functional_units
    capacity = machine.capacity_vector

    # The "  t │ " prefix is exactly as wide as each row's " {t:2d} │ " prefix,
    # and every cell is 8 characters, so the header columns align with the rows.
    header = "  t │ " + " │ ".join(f"{fu:^8s}" for fu in fu_names) + " │"
    lines = [f"Modular RRT (I={I}):", header, "─" * len(header)]
    for t in range(I):
        cells = []
        for f_idx in range(len(fu_names)):
            val = mod_rrt[t, f_idx]
            marker = "!" if val > capacity[f_idx] else " "
            # int() so the 'd' format also accepts integral float entries
            # (dtype of the modular RRT is not visible here — TODO confirm).
            cells.append(f" {int(val):^6d}{marker}")
        lines.append(f" {t:2d} │ " + " │ ".join(cells) + " │")
    return "\n".join(lines)