"""
Visualization: Generate schedule diagrams and warp assignment views.

Based on the figures in the paper (Figures 1, 3, 7, 9).
"""

import numpy as np
from typing import Dict, List, Optional

from twill.graph import DependenceGraph
from twill.smt_joint import JointSWPWSResult
from twill.modulo_scheduler import ModuloScheduleResult


# Color palette for different instruction types / warps
WARP_COLORS = [
    '#4CAF50',  # Green - variable latency / producer
    '#E91E63',  # Pink - compute / TC
    '#2196F3',  # Blue - compute / EXP
    '#FF9800',  # Orange - compute / misc
    '#9C27B0',  # Purple
    '#00BCD4',  # Cyan
    '#795548',  # Brown
    '#607D8B',  # Blue-grey
]

FU_COLORS = {
    'TC': '#E91E63',    # Pink for Tensor Core
    'EXP': '#2196F3',   # Blue for Exponential
    'TMA': '#4CAF50',   # Green for TMA loads
    'TMEM': '#FF9800',  # Orange for Tensor Memory
}

# Minimum inner width (the text between "║ " and " ║") of the text schedule
# box. The box grows past this when a row is wider, so the right-hand border
# always stays aligned.
_BOX_MIN_INNER_WIDTH = 58


def _render_box(
    rows: List[Optional[tuple]],
    min_inner: int = _BOX_MIN_INNER_WIDTH,
) -> str:
    """Render text rows inside a double-line box.

    Args:
        rows: Each entry is either ``None`` (rendered as a ``╠═╣`` divider)
            or a ``(text, align)`` pair, where ``align`` is a format
            alignment character such as ``'<'`` or ``'^'``.
        min_inner: Minimum inner width of the box.

    Returns:
        The box as a newline-joined string. The inner width is the larger of
        ``min_inner`` and the longest row text.
    """
    inner = max([min_inner] + [len(row[0]) for row in rows if row is not None])
    bar = '═' * (inner + 2)
    out = [f"╔{bar}╗"]
    for row in rows:
        if row is None:
            out.append(f"╠{bar}╣")
        else:
            text, align = row
            out.append(f"║ {text:{align}{inner}s} ║")
    out.append(f"╚{bar}╝")
    return "\n".join(out)


def _timeline_rows(
    graph: DependenceGraph,
    result: JointSWPWSResult,
    fu_names,
) -> List[tuple]:
    """Build the per-cycle timeline rows for cycles ``0 .. result.length - 1``.

    Copy ``i`` of instruction ``v`` issues at ``schedule[v] + i * I``. Copies
    whose issue time falls outside ``[0, length)`` are not shown. A cycle with
    several issues gets one row per issue, and an idle cycle gets a row of
    ``·`` placeholders.

    Returns:
        A list of ``(text, '<')`` pairs suitable for :func:`_render_box`.
    """
    I = result.I
    L = result.length
    M = result.schedule
    wa = result.warp_assignment

    # Bucket issues by cycle in a single pass instead of rescanning every
    # instruction for every cycle. Appending in (graph.V, copy) order keeps
    # the per-cycle row order stable.
    issued: Dict[int, list] = {}
    for v in graph.V:
        for i in range(result.num_copies):
            t = M[v.name] + i * I
            if 0 <= t < L:
                issued.setdefault(t, []).append((v, i))

    idle_cells = "".join(f" {'·':^8s} │" for _ in fu_names)
    rows = []
    for t in range(L):
        # Width 5 matches the "Cycle" header cell so the separators line up.
        prefix = f"{t:>5d} │"
        ops = issued.get(t)
        if not ops:
            rows.append((prefix + idle_cells, '<'))
            continue
        for v, i in ops:
            cells = ""
            for f_idx in range(len(fu_names)):
                # Show the instruction under every FU its RRT uses at any cycle.
                used = int(v.rrt[:, f_idx].sum()) > 0
                cell = v.name if used else '·'
                cells += f" {cell:^8s} │"
            row = f"{prefix}{cells} W{wa.warp_of(v.name)} "
            if i > 0:
                row += f"(i+{i})"
            rows.append((row, '<'))
    return rows


def visualize_schedule(
    graph: DependenceGraph,
    result: JointSWPWSResult,
    output_path: Optional[str] = None,
    title: str = "Twill Schedule",
) -> str:
    """Generate a text-based visualization of the schedule.

    Shows a timeline with instructions placed at their scheduled cycles. In
    the text representation, the warp assignment is shown as a ``W<n>``
    marker at the end of each row. Below the timeline are a per-warp
    instruction summary and the list of dependence edges that cross warps.

    Args:
        graph: The dependence graph.
        result: The joint SWP+WS result.
        output_path: If provided, also generate a matplotlib figure at this
            path.
        title: Title for the visualization.

    Returns:
        String representation of the schedule.
    """
    wa = result.warp_assignment
    fu_names = graph.machine.functional_units

    header = "Cycle │" + "".join(f" {fu:^8s} │" for fu in fu_names) + " Warp "
    rows: List[Optional[tuple]] = [
        (title, '^'),
        (f"I={result.I}, L={result.length}, copies={result.num_copies}", '<'),
        None,
        (header, '<'),
        ('─' * len(header), '^'),
    ]
    rows.extend(_timeline_rows(graph, result, fu_names))
    rows.append(None)

    # Warp assignment summary; warps with no instructions are omitted.
    rows.append(("Warp Assignments:", '^'))
    for w in range(graph.machine.num_warps):
        instrs = wa.instructions_on_warp(w)
        if instrs:
            label = wa.warp_names.get(w, f"Warp {w}")
            rows.append((f" {label}: {', '.join(instrs)}", '<'))

    # Dependence edges whose endpoints sit on different warps.
    barriers = []
    for edge in graph.E:
        src_warp = wa.warp_of(edge.src)
        dst_warp = wa.warp_of(edge.dst)
        if src_warp != dst_warp:
            barriers.append(f" {edge.src}(W{src_warp}) → {edge.dst}(W{dst_warp})")
    if barriers:
        rows.append(('', '^'))
        rows.append(("Cross-Warp Barriers:", '^'))
        rows.extend((b, '<') for b in barriers)

    text_viz = _render_box(rows)

    # Optionally generate matplotlib figure
    if output_path:
        _generate_matplotlib_figure(graph, result, output_path, title)

    return text_viz


def _generate_matplotlib_figure(
    graph: DependenceGraph,
    result: JointSWPWSResult,
    output_path: str,
    title: str,
):
    """Generate a matplotlib figure of the schedule (Gantt chart style).

    Cycles run down the Y axis and functional units across the X axis. Each
    cycle of each instruction copy that uses a functional unit (per its RRT)
    becomes a box colored by warp. Copies issuing at or after
    ``result.length`` are skipped.

    This is best-effort: if matplotlib is not installed, a message is printed
    and nothing is written. The figure is always closed, even if drawing or
    saving raises.
    """
    try:
        import matplotlib.pyplot as plt
        import matplotlib.patches as mpatches
    except ImportError:
        print("matplotlib not available for figure generation")
        return

    I = result.I
    L = result.length
    n_copies = result.num_copies
    M = result.schedule
    wa = result.warp_assignment
    machine = graph.machine

    # Y-axis: time (cycles), X-axis: functional units
    fu_names = machine.functional_units
    n_fus = len(fu_names)
    bar_width = 0.8

    fig, ax = plt.subplots(1, 1, figsize=(14, max(6, L * 0.4)))
    try:
        for v in graph.V:
            warp = wa.warp_of(v.name)
            color = WARP_COLORS[warp % len(WARP_COLORS)]
            for i in range(n_copies):
                abs_time = M[v.name] + i * I
                if abs_time >= L:
                    continue
                label = f"{v.name}" if i == 0 else f"{v.name}+{i}"
                for c in range(v.cycles):
                    for f_idx in range(n_fus):
                        if v.rrt[c, f_idx] > 0:
                            rect = mpatches.FancyBboxPatch(
                                (f_idx - bar_width / 2, abs_time + c),
                                bar_width, 1,
                                boxstyle="round,pad=0.05",
                                facecolor=color, edgecolor='black',
                                linewidth=0.5, alpha=0.8,
                            )
                            ax.add_patch(rect)
                            ax.text(f_idx, abs_time + c + 0.5, label,
                                    ha='center', va='center', fontsize=8,
                                    fontweight='bold', color='white')

        # Formatting
        ax.set_xlim(-0.5, n_fus - 0.5)
        ax.set_ylim(-0.5, L + 0.5)
        ax.set_xticks(range(n_fus))
        ax.set_xticklabels(fu_names, fontsize=10)
        ax.set_yticks(range(L))
        ax.set_ylabel("Clock Cycle", fontsize=12)
        ax.set_xlabel("Functional Unit", fontsize=12)
        ax.set_title(f"{title}\nI={I}, L={L}, copies={n_copies}", fontsize=14)
        ax.invert_yaxis()
        ax.grid(True, alpha=0.3)

        # Legend for warps (only warps that carry instructions)
        legend_patches = []
        for w in range(machine.num_warps):
            instrs = wa.instructions_on_warp(w)
            if instrs:
                label = wa.warp_names.get(w, f"Warp {w}")
                legend_patches.append(
                    mpatches.Patch(color=WARP_COLORS[w % len(WARP_COLORS)],
                                   label=f"{label}: {', '.join(instrs)}")
                )
        # Skip the legend entirely rather than drawing an empty box.
        if legend_patches:
            ax.legend(handles=legend_patches, loc='upper right', fontsize=8)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    print(f"Schedule figure saved to {output_path}")


def print_modular_rrt(
    graph: DependenceGraph,
    schedule: ModuloScheduleResult,
) -> str:
    """Format the modular RRT as a table and return it as a string.

    Note: despite the name, nothing is printed; the caller receives the text.
    There is one row per cycle ``t`` in ``0 .. I - 1`` and one column per
    functional unit. A cell is suffixed with ``!`` when its usage exceeds the
    unit's capacity (``graph.machine.capacity_vector``).
    """
    from twill.modulo_scheduler import compute_modular_rrt

    mod_rrt = compute_modular_rrt(graph, schedule)
    I = schedule.I
    machine = graph.machine
    fu_names = machine.functional_units
    capacity = machine.capacity_vector

    # Each column is at least 8 wide, or wider for a long FU name. The cycle
    # column is at least 2 wide, matching the original " t" header.
    widths = [max(8, len(fu)) for fu in fu_names]
    t_width = max(2, len(str(max(I - 1, 0))))

    lines = [f"Modular RRT (I={I}):"]
    header = (
        f"{'t':>{t_width}s} │ "
        + " │ ".join(f"{fu:^{w}s}" for fu, w in zip(fu_names, widths))
        + " │"
    )
    lines.append(header)
    lines.append("─" * len(header))
    for t in range(I):
        cells = []
        for f_idx, w in enumerate(widths):
            val = int(mod_rrt[t, f_idx])
            marker = "!" if val > capacity[f_idx] else " "
            # Value plus the 1-char overflow marker fill exactly the header cell width.
            cells.append(f"{val:^{w - 1}d}{marker}")
        lines.append(f"{t:>{t_width}d} │ " + " │ ".join(cells) + " │")
    return "\n".join(lines)