""" Visualization utilities for training results. Generates: - Reward curves (line plots) - Agent comparison (bar charts) - Metrics radar charts """ import matplotlib matplotlib.use("Agg") # Non-interactive backend import matplotlib.pyplot as plt import numpy as np from typing import Dict, List, Optional # ─── Style Configuration ───────────────────────────────────────────────────── COLORS = { "Random": "#e74c3c", # Red "Rule-Based": "#3498db", # Blue "Q-Learning": "#2ecc71", # Green } plt.rcParams.update({ "figure.facecolor": "#1a1a2e", "axes.facecolor": "#16213e", "axes.edgecolor": "#0f3460", "axes.labelcolor": "#e0e0e0", "text.color": "#e0e0e0", "xtick.color": "#a0a0a0", "ytick.color": "#a0a0a0", "grid.color": "#0f3460", "grid.alpha": 0.5, "font.size": 11, "font.family": "sans-serif", }) def _smooth(values: List[float], window: int = 10) -> np.ndarray: """Apply a moving average to smooth noisy curves.""" if len(values) < window: return np.array(values) kernel = np.ones(window) / window return np.convolve(values, kernel, mode="valid") # ─── Reward Curves ─────────────────────────────────────────────────────────── def plot_rewards( agent_rewards: Dict[str, List[float]], save_path: str = "logs/reward_curves.png", title: str = "Training Reward Curves", ): """Plot reward curves for multiple agents. Args: agent_rewards: Dict mapping agent_name → list of episode rewards. save_path: File path to save the plot. title: Plot title. """ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6)) # Raw rewards (left) for name, rewards in agent_rewards.items(): color = COLORS.get(name, "#ffffff") ax1.plot(rewards, alpha=0.3, color=color, linewidth=0.8) smoothed = _smooth(rewards, window=15) ax1.plot( range(len(rewards) - len(smoothed), len(rewards)), smoothed, color=color, linewidth=2.5, label=f"{name} (smoothed)", ) ax1.set_xlabel("Episode") ax1.set_ylabel("Total Reward") ax1.set_title("Raw + Smoothed Rewards") ax1.legend(loc="upper left", framealpha=0.7) ax1.grid(True) # Cumulative average (right) for name, rewards in agent_rewards.items(): color = COLORS.get(name, "#ffffff") cumavg = np.cumsum(rewards) / np.arange(1, len(rewards) + 1) ax2.plot(cumavg, color=color, linewidth=2, label=name) ax2.set_xlabel("Episode") ax2.set_ylabel("Cumulative Average Reward") ax2.set_title("Cumulative Average Performance") ax2.legend(loc="lower right", framealpha=0.7) ax2.grid(True) fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02) plt.tight_layout() plt.savefig(save_path, dpi=150, bbox_inches="tight") plt.close() print(f" 📊 Saved reward curves → {save_path}") # ─── Agent Comparison Bar Chart ────────────────────────────────────────────── def plot_comparison( comparison_data: Dict[str, Dict[str, float]], save_path: str = "logs/agent_comparison.png", ): """Plot a grouped bar chart comparing agent performance. Args: comparison_data: Dict mapping agent_name → {metric_name → value}. save_path: File path to save the plot. """ agents = list(comparison_data.keys()) metric_names = ["avg_reward", "task_completion", "message_response", "efficiency"] metric_labels = ["Avg Reward", "Task Completion", "Msg Response", "Efficiency/100"] fig, axes = plt.subplots(1, 4, figsize=(18, 5)) for idx, (metric, label) in enumerate(zip(metric_names, metric_labels)): ax = axes[idx] values = [comparison_data[a].get(metric, 0) for a in agents] colors = [COLORS.get(a, "#ffffff") for a in agents] bars = ax.bar(agents, values, color=colors, edgecolor="#ffffff33", linewidth=0.5) # Add value labels on bars for bar, val in zip(bars, values): ax.text( bar.get_x() + bar.get_width() / 2, bar.get_height() + max(values) * 0.02, f"{val:.2f}", ha="center", va="bottom", fontsize=10, fontweight="bold", ) ax.set_title(label, fontsize=12) ax.set_ylabel(label) ax.tick_params(axis="x", rotation=30) ax.grid(axis="y", alpha=0.3) fig.suptitle("Agent Comparison", fontsize=14, fontweight="bold", y=1.02) plt.tight_layout() plt.savefig(save_path, dpi=150, bbox_inches="tight") plt.close() print(f" 📊 Saved comparison chart → {save_path}") # ─── Metrics Radar Chart ──────────────────────────────────────────────────── def plot_metrics( metrics: Dict[str, float], save_path: str = "logs/rl_metrics.png", title: str = "Agent Performance Metrics", ): """Plot a radar/spider chart of metrics. Args: metrics: Dict of metric_name → value (0–1 scale or normalized). save_path: File path to save the plot. title: Plot title. """ # Normalize metrics to 0–1 scale display_metrics = { "Completion": metrics.get("task_completion_rate", 0), "Hi-Priority": metrics.get("high_priority_completion", 0), "Msg Response": metrics.get("message_response_rate", 0), "Efficiency": metrics.get("efficiency_score", 0) / 100, "No Conflicts": max(0, 1 - metrics.get("conflict_count", 0) / 5), } categories = list(display_metrics.keys()) values = list(display_metrics.values()) # Close the radar values += [values[0]] angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist() angles += [angles[0]] fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True)) ax.fill(angles, values, color="#2ecc71", alpha=0.25) ax.plot(angles, values, color="#2ecc71", linewidth=2.5, marker="o", markersize=8) ax.set_xticks(angles[:-1]) ax.set_xticklabels(categories, fontsize=12) ax.set_ylim(0, 1) ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0]) ax.set_yticklabels(["20%", "40%", "60%", "80%", "100%"], fontsize=9) # Add value labels for angle, value, cat in zip(angles[:-1], values[:-1], categories): ax.text( angle, value + 0.08, f"{value:.0%}", ha="center", va="center", fontsize=10, fontweight="bold", ) ax.set_title(title, fontsize=14, fontweight="bold", pad=20) plt.tight_layout() plt.savefig(save_path, dpi=150, bbox_inches="tight") plt.close() print(f" 📊 Saved metrics radar → {save_path}")