| """ |
| Visualization utilities for training results. |
| |
| Generates: |
| - Reward curves (line plots) |
| - Agent comparison (bar charts) |
| - Metrics radar charts |
| """ |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| from typing import Dict, List, Optional |
|
|
|
|
| |
|
|
# Fixed colour per known agent so every chart uses the same palette.
# The plotting functions look names up with .get() and fall back to
# white ("#ffffff") for agents that are not listed here.
COLORS: Dict[str, str] = {
    "Random": "#e74c3c",      # red
    "Rule-Based": "#3498db",  # blue
    "Q-Learning": "#2ecc71",  # green
}
|
|
# Global dark theme, applied once at import time to every figure that
# this module creates afterwards.
plt.rcParams.update(
    {
        # Canvas and axes backgrounds / borders.
        "figure.facecolor": "#1a1a2e",
        "axes.facecolor": "#16213e",
        "axes.edgecolor": "#0f3460",
        # Text: labels, titles and tick marks in light greys.
        "axes.labelcolor": "#e0e0e0",
        "text.color": "#e0e0e0",
        "xtick.color": "#a0a0a0",
        "ytick.color": "#a0a0a0",
        # Grid lines share the axes border colour, half transparent.
        "grid.color": "#0f3460",
        "grid.alpha": 0.5,
        # Typography.
        "font.size": 11,
        "font.family": "sans-serif",
    }
)
|
|
|
|
| def _smooth(values: List[float], window: int = 10) -> np.ndarray: |
| """Apply a moving average to smooth noisy curves.""" |
| if len(values) < window: |
| return np.array(values) |
| kernel = np.ones(window) / window |
| return np.convolve(values, kernel, mode="valid") |
|
|
|
|
| |
|
|
def plot_rewards(
    agent_rewards: Dict[str, List[float]],
    save_path: str = "logs/reward_curves.png",
    title: str = "Training Reward Curves",
) -> None:
    """Plot reward curves for multiple agents and save them as an image.

    The figure has two panels. The left panel draws each agent's raw
    per-episode rewards (faint) with a 15-episode moving average on top;
    the right panel draws the running mean of each agent's rewards.

    Args:
        agent_rewards: Dict mapping agent_name → list of episode rewards.
        save_path: File path to save the plot. Missing parent directories
            are created before saving.
        title: Plot title.
    """
    import os

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    try:
        for name, rewards in agent_rewards.items():
            color = COLORS.get(name, "#ffffff")
            n_episodes = len(rewards)

            # Left panel: faint raw curve plus a bold smoothed curve.
            ax1.plot(rewards, alpha=0.3, color=color, linewidth=0.8)
            smoothed = _smooth(rewards, window=15)
            # The moving average is shorter than the raw series; align it
            # with the trailing episodes so each point sits at the end of
            # the window it summarises.
            ax1.plot(
                range(n_episodes - len(smoothed), n_episodes),
                smoothed,
                color=color,
                linewidth=2.5,
                label=f"{name} (smoothed)",
            )

            # Right panel: mean of all rewards up to and including each episode.
            cumavg = np.cumsum(rewards) / np.arange(1, n_episodes + 1)
            ax2.plot(cumavg, color=color, linewidth=2, label=name)

        ax1.set_xlabel("Episode")
        ax1.set_ylabel("Total Reward")
        ax1.set_title("Raw + Smoothed Rewards")
        ax1.grid(True)

        ax2.set_xlabel("Episode")
        ax2.set_ylabel("Cumulative Average Reward")
        ax2.set_title("Cumulative Average Performance")
        ax2.grid(True)

        # A legend without labelled artists only produces a warning.
        if agent_rewards:
            ax1.legend(loc="upper left", framealpha=0.7)
            ax2.legend(loc="lower right", framealpha=0.7)

        fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        fig.tight_layout()

        # The default path lives under "logs/", which may not exist yet.
        if isinstance(save_path, (str, os.PathLike)):
            out_dir = os.path.dirname(os.fspath(save_path))
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    print(f" Saved reward curves → {save_path}")
|
|
|
|
| |
|
|
def plot_comparison(
    comparison_data: Dict[str, Dict[str, float]],
    save_path: str = "logs/agent_comparison.png",
) -> None:
    """Plot one bar chart per metric comparing agent performance.

    Four side-by-side panels are drawn, one for each of the metrics
    ``avg_reward``, ``task_completion``, ``message_response`` and
    ``efficiency``. A metric missing for an agent is plotted as 0.

    Args:
        comparison_data: Dict mapping agent_name → {metric_name → value}.
            An empty dict yields four empty, titled panels.
        save_path: File path to save the plot. Missing parent directories
            are created before saving.
    """
    import os

    agents = list(comparison_data)
    panels = [
        ("avg_reward", "Avg Reward"),
        ("task_completion", "Task Completion"),
        ("message_response", "Msg Response"),
        ("efficiency", "Efficiency/100"),
    ]
    # Bar colours depend only on the agent, so compute them once.
    bar_colors = [COLORS.get(agent, "#ffffff") for agent in agents]

    fig, axes = plt.subplots(1, len(panels), figsize=(18, 5))
    try:
        for ax, (metric, label) in zip(axes, panels):
            values = [comparison_data[agent].get(metric, 0) for agent in agents]

            if agents:
                bars = ax.bar(
                    agents,
                    values,
                    color=bar_colors,
                    edgecolor="#ffffff33",
                    linewidth=0.5,
                )

                # Gap between a bar's tip and its value label: 2% of the
                # largest magnitude, so it also works for negative values.
                pad = max(abs(v) for v in values) * 0.02
                for bar, val in zip(bars, values):
                    height = bar.get_height()
                    # Negative bars grow downwards; put their label below
                    # the tip instead of on top of the bar.
                    points_down = height < 0
                    ax.text(
                        bar.get_x() + bar.get_width() / 2,
                        height - pad if points_down else height + pad,
                        f"{val:.2f}",
                        ha="center",
                        va="top" if points_down else "bottom",
                        fontsize=10,
                        fontweight="bold",
                    )

            ax.set_title(label, fontsize=12)
            ax.set_ylabel(label)
            ax.tick_params(axis="x", rotation=30)
            ax.grid(axis="y", alpha=0.3)

        fig.suptitle("Agent Comparison", fontsize=14, fontweight="bold", y=1.02)
        fig.tight_layout()

        # The default path lives under "logs/", which may not exist yet.
        if isinstance(save_path, (str, os.PathLike)):
            out_dir = os.path.dirname(os.fspath(save_path))
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    print(f" Saved comparison chart → {save_path}")
|
|
|
|
| |
|
|
def plot_metrics(
    metrics: Dict[str, float],
    save_path: str = "logs/rl_metrics.png",
    title: str = "Agent Performance Metrics",
) -> None:
    """Plot a radar/spider chart of metrics and save it as an image.

    Five axes are drawn, each on a 0-1 (0%-100%) radial scale:

    - "Completion":   ``task_completion_rate``
    - "Hi-Priority":  ``high_priority_completion``
    - "Msg Response": ``message_response_rate``
    - "Efficiency":   ``efficiency_score`` divided by 100
    - "No Conflicts": ``1 - conflict_count / 5``, floored at 0

    Args:
        metrics: Dict of metric_name → value. Keys that are absent are
            treated as 0.
        save_path: File path to save the plot. Missing parent directories
            are created before saving.
        title: Plot title.
    """
    import os

    display_metrics = {
        "Completion": metrics.get("task_completion_rate", 0),
        "Hi-Priority": metrics.get("high_priority_completion", 0),
        "Msg Response": metrics.get("message_response_rate", 0),
        "Efficiency": metrics.get("efficiency_score", 0) / 100,
        # Five or more conflicts score 0; zero conflicts score 1.
        "No Conflicts": max(0, 1 - metrics.get("conflict_count", 0) / 5),
    }
    categories = list(display_metrics)
    scores = list(display_metrics.values())

    # One spoke per category, evenly spaced around the circle.
    angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
    # Repeat the first point at the end so the outline forms a closed polygon.
    closed_angles = angles + angles[:1]
    closed_scores = scores + scores[:1]

    fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
    try:
        ax.fill(closed_angles, closed_scores, color="#2ecc71", alpha=0.25)
        ax.plot(
            closed_angles,
            closed_scores,
            color="#2ecc71",
            linewidth=2.5,
            marker="o",
            markersize=8,
        )

        ax.set_xticks(angles)
        ax.set_xticklabels(categories, fontsize=12)
        ax.set_ylim(0, 1)
        ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
        ax.set_yticklabels(["20%", "40%", "60%", "80%", "100%"], fontsize=9)

        # Percentage label just outside each vertex.
        for angle, score in zip(angles, scores):
            ax.text(
                angle,
                score + 0.08,
                f"{score:.0%}",
                ha="center",
                va="center",
                fontsize=10,
                fontweight="bold",
            )

        ax.set_title(title, fontsize=14, fontweight="bold", pad=20)
        fig.tight_layout()

        # The default path lives under "logs/", which may not exist yet.
        if isinstance(save_path, (str, os.PathLike)):
            out_dir = os.path.dirname(os.fspath(save_path))
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    print(f" Saved metrics radar → {save_path}")
|
|