mahammadaftab's picture
clean initial commit
62851e9
"""
Visualization utilities for training results.
Generates:
- Reward curves (line plots)
- Agent comparison (bar charts)
- Metrics radar charts
"""
import os
from typing import Dict, List, Optional

import matplotlib
matplotlib.use("Agg")  # Non-interactive backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
import numpy as np
# ─── Style Configuration ─────────────────────────────────────────────────────
# Fixed per-agent colors so an agent looks the same in every chart.
# Agents not listed here are drawn in white by the plotting functions.
COLORS = {
    "Random": "#e74c3c",      # red
    "Rule-Based": "#3498db",  # blue
    "Q-Learning": "#2ecc71",  # green
}

# Dark theme pushed into matplotlib's global rcParams at import time, so it
# applies to every figure created afterwards in this process.
_DARK_THEME = {
    # Backgrounds and axes frame.
    "figure.facecolor": "#1a1a2e",
    "axes.facecolor": "#16213e",
    "axes.edgecolor": "#0f3460",
    # Text and tick labels.
    "axes.labelcolor": "#e0e0e0",
    "text.color": "#e0e0e0",
    "xtick.color": "#a0a0a0",
    "ytick.color": "#a0a0a0",
    # Grid lines.
    "grid.color": "#0f3460",
    "grid.alpha": 0.5,
    # Typography.
    "font.size": 11,
    "font.family": "sans-serif",
}
plt.rcParams.update(_DARK_THEME)
def _smooth(values: List[float], window: int = 10) -> np.ndarray:
    """Return a trailing moving average of *values* to smooth noisy curves.

    The average is computed with ``np.convolve(..., mode="valid")``, so only
    positions where a full window fits are produced: the result has
    ``len(values) - window + 1`` points.

    Args:
        values: Sequence of numeric samples (e.g. per-episode rewards).
        window: Number of consecutive samples averaged per output point.
            Must be at least 1.

    Returns:
        A float array. When there are fewer samples than ``window``, no full
        window fits and the input is returned unsmoothed (as floats, so the
        dtype does not depend on the input length).

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    data = np.asarray(values, dtype=float)
    if len(data) < window:
        # Too short for even one full window: fall back to the raw samples.
        return data
    kernel = np.ones(window) / window
    return np.convolve(data, kernel, mode="valid")
# ─── Reward Curves ───────────────────────────────────────────────────────────
def plot_rewards(
    agent_rewards: Dict[str, List[float]],
    save_path: str = "logs/reward_curves.png",
    title: str = "Training Reward Curves",
):
    """Plot reward curves for several agents side by side and save the figure.

    Left panel: each agent's raw per-episode rewards (faint line) overlaid
    with a 15-episode moving average. Right panel: the cumulative average
    reward, i.e. the mean of all episodes up to and including each one.

    Args:
        agent_rewards: Dict mapping agent_name → list of episode rewards.
            Names present in ``COLORS`` use their fixed color; any other
            agent is drawn in white.
        save_path: File path to save the plot. Missing parent directories
            are created.
        title: Figure title.
    """
    directory = os.path.dirname(save_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    fig, (ax_raw, ax_avg) = plt.subplots(1, 2, figsize=(16, 6))
    try:
        for name, rewards in agent_rewards.items():
            color = COLORS.get(name, "#ffffff")
            n_episodes = len(rewards)

            # Raw rewards (left), drawn faintly underneath the trend line.
            ax_raw.plot(rewards, alpha=0.3, color=color, linewidth=0.8)
            # The moving average can be shorter than the raw series, so it is
            # right-aligned: each averaged point sits at the last episode of
            # its window.
            smoothed = _smooth(rewards, window=15)
            ax_raw.plot(
                range(n_episodes - len(smoothed), n_episodes),
                smoothed,
                color=color,
                linewidth=2.5,
                label=f"{name} (smoothed)",
            )

            # Cumulative average (right): running mean of episodes 1..k.
            cumavg = np.cumsum(rewards) / np.arange(1, n_episodes + 1)
            ax_avg.plot(cumavg, color=color, linewidth=2, label=name)

        ax_raw.set_xlabel("Episode")
        ax_raw.set_ylabel("Total Reward")
        ax_raw.set_title("Raw + Smoothed Rewards")
        ax_raw.legend(loc="upper left", framealpha=0.7)
        ax_raw.grid(True)

        ax_avg.set_xlabel("Episode")
        ax_avg.set_ylabel("Cumulative Average Reward")
        ax_avg.set_title("Cumulative Average Performance")
        ax_avg.legend(loc="lower right", framealpha=0.7)
        ax_avg.grid(True)

        fig.suptitle(title, fontsize=14, fontweight="bold", y=1.02)
        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails; otherwise
        # pyplot keeps it registered and it accumulates across calls.
        plt.close(fig)
    print(f" 📊 Saved reward curves → {save_path}")
# ─── Agent Comparison Bar Chart ──────────────────────────────────────────────
def plot_comparison(
    comparison_data: Dict[str, Dict[str, float]],
    save_path: str = "logs/agent_comparison.png",
):
    """Plot one bar chart per metric comparing agents, and save the figure.

    Four side-by-side panels are drawn (avg_reward, task_completion,
    message_response, efficiency). Each has one bar per agent, annotated with
    its value. A metric missing for an agent is plotted as 0.

    Args:
        comparison_data: Dict mapping agent_name → {metric_name → value}.
        save_path: File path to save the plot. Missing parent directories
            are created.
    """
    agents = list(comparison_data)
    metric_names = ["avg_reward", "task_completion", "message_response", "efficiency"]
    # NOTE(review): the last label reads "Efficiency/100", but the raw
    # "efficiency" value is plotted unscaled; confirm the caller pre-divides it.
    metric_labels = ["Avg Reward", "Task Completion", "Msg Response", "Efficiency/100"]
    colors = [COLORS.get(agent, "#ffffff") for agent in agents]

    directory = os.path.dirname(save_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    fig, axes = plt.subplots(1, len(metric_names), figsize=(18, 5))
    try:
        for ax, metric, label in zip(axes, metric_names, metric_labels):
            values = [comparison_data[agent].get(metric, 0) for agent in agents]
            bars = ax.bar(agents, values, color=colors, edgecolor="#ffffff33", linewidth=0.5)

            # Keep value labels 2% of the largest bar magnitude away from the
            # bar end. abs() keeps the gap positive when values are negative
            # (e.g. a negative avg_reward); negative bars get their label
            # below the bar instead of inside it.
            gap = 0.02 * max((abs(v) for v in values), default=0)
            for bar, val in zip(bars, values):
                if val >= 0:
                    y, va = bar.get_height() + gap, "bottom"
                else:
                    y, va = bar.get_height() - gap, "top"
                ax.text(
                    bar.get_x() + bar.get_width() / 2,
                    y,
                    f"{val:.2f}",
                    ha="center",
                    va=va,
                    fontsize=10,
                    fontweight="bold",
                )

            ax.set_title(label, fontsize=12)
            ax.set_ylabel(label)
            ax.tick_params(axis="x", rotation=30)
            ax.grid(axis="y", alpha=0.3)

        fig.suptitle("Agent Comparison", fontsize=14, fontweight="bold", y=1.02)
        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    print(f" 📊 Saved comparison chart → {save_path}")
# ─── Metrics Radar Chart ────────────────────────────────────────────────────
def plot_metrics(
    metrics: Dict[str, float],
    save_path: str = "logs/rl_metrics.png",
    title: str = "Agent Performance Metrics",
):
    """Plot a radar/spider chart of a single agent's metrics and save it.

    Args:
        metrics: Dict of metric_name → value. The chart reads these keys:
            - ``task_completion_rate``, ``high_priority_completion`` and
              ``message_response_rate``, used unchanged;
            - ``efficiency_score``, divided by 100;
            - ``conflict_count``, inverted to ``max(0, 1 - conflict_count / 5)``.
            Missing keys default to 0. The radial axis is fixed to 0–1, so
            the resulting values are expected on that scale.
        save_path: File path to save the plot. Missing parent directories
            are created.
        title: Plot title.
    """
    # Map the raw metrics onto the chart's 0–1 radial axis.
    display_metrics = {
        "Completion": metrics.get("task_completion_rate", 0),
        "Hi-Priority": metrics.get("high_priority_completion", 0),
        "Msg Response": metrics.get("message_response_rate", 0),
        "Efficiency": metrics.get("efficiency_score", 0) / 100,
        # 0 conflicts -> 1.0; 5 or more conflicts -> 0.0.
        "No Conflicts": max(0, 1 - metrics.get("conflict_count", 0) / 5),
    }
    categories = list(display_metrics)
    values = list(display_metrics.values())

    # One evenly spaced spoke per category. The first point is repeated at
    # the end so the plotted polygon closes.
    angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
    closed_angles = angles + angles[:1]
    closed_values = values + values[:1]

    directory = os.path.dirname(save_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
    try:
        ax.fill(closed_angles, closed_values, color="#2ecc71", alpha=0.25)
        ax.plot(
            closed_angles,
            closed_values,
            color="#2ecc71",
            linewidth=2.5,
            marker="o",
            markersize=8,
        )
        ax.set_xticks(angles)
        ax.set_xticklabels(categories, fontsize=12)
        ax.set_ylim(0, 1)
        ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
        ax.set_yticklabels(["20%", "40%", "60%", "80%", "100%"], fontsize=9)

        # Percentage label just outside each vertex.
        for angle, value in zip(angles, values):
            ax.text(
                angle, value + 0.08, f"{value:.0%}",
                ha="center", va="center", fontsize=10, fontweight="bold",
            )

        ax.set_title(title, fontsize=14, fontweight="bold", pad=20)
        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    print(f" 📊 Saved metrics radar → {save_path}")