# Source upload metadata (algorembrant, "Upload 76 files", commit 8744e5e verified)
# — converted to a comment so the file parses as Python.
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.gridspec import GridSpec
from matplotlib.patches import FancyArrowPatch
import os
import re
def setup_figure(title, rows, cols):
    """Create a 20x10 figure with a bold suptitle and a GridSpec layout.

    constrained_layout is enabled so subplot spacing never triggers
    tight-layout warnings. Returns the (figure, gridspec) pair.
    """
    figure = plt.figure(figsize=(20, 10), constrained_layout=True)
    figure.suptitle(title, fontsize=18, fontweight='bold')
    layout = GridSpec(rows, cols, figure=figure)
    return figure, layout
def plot_agent_env_loop(ax):
    """MDP & Environment: the agent/environment interaction loop as a flowchart."""
    ax.axis('off')
    ax.set_title("Agent-Environment Interaction", fontsize=12, fontweight='bold')
    node_style = dict(boxstyle="round,pad=0.8", fc="ivory", ec="black", lw=1.5)
    ax.text(0.5, 0.8, "Agent", ha="center", va="center", bbox=node_style, fontsize=12)
    ax.text(0.5, 0.2, "Environment", ha="center", va="center", bbox=node_style, fontsize=12)
    # Downward arc: agent sends an action into the environment.
    ax.annotate("Action $A_t$", xy=(0.5, 0.35), xytext=(0.5, 0.65),
                arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5", lw=2))
    # Upward arc: environment replies with the next state and the reward.
    ax.annotate("State $S_{t+1}$, Reward $R_{t+1}$", xy=(0.5, 0.65), xytext=(0.5, 0.35),
                arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5", lw=2, color='green'))
def plot_mdp_graph(ax):
    """MDP & Environment: transition graph with probability-labelled edges."""
    graph = nx.DiGraph()
    # Edge attributes supplied as dicts so each transition carries its probability.
    transitions = [
        ('S0', 'S1', {'weight': 0.8}), ('S0', 'S2', {'weight': 0.2}),
        ('S1', 'S2', {'weight': 1.0}), ('S2', 'S0', {'weight': 0.5}), ('S2', 'S2', {'weight': 0.5})
    ]
    graph.add_edges_from(transitions)
    layout = nx.spring_layout(graph, seed=42)  # seeded for a reproducible drawing
    nx.draw_networkx_nodes(ax=ax, G=graph, pos=layout, node_size=1500, node_color='lightblue')
    nx.draw_networkx_labels(ax=ax, G=graph, pos=layout, font_weight='bold')
    labels = {}
    for src, dst, attrs in graph.edges(data=True):
        labels[(src, dst)] = f"P={attrs['weight']}"
    nx.draw_networkx_edges(ax=ax, G=graph, pos=layout, arrowsize=20, edge_color='gray', connectionstyle="arc3,rad=0.1")
    nx.draw_networkx_edge_labels(ax=ax, G=graph, pos=layout, edge_labels=labels, font_size=9)
    ax.set_title("MDP State Transition Graph", fontsize=12, fontweight='bold')
    ax.axis('off')
def plot_reward_landscape(fig, gs):
    """MDP & Environment: 3D surface of a synthetic reward function."""
    # Prefer grid slot (0, 1); fall back to (0, 0) when the layout is
    # narrower (dashboard vs. standalone-save usage).
    try:
        ax = fig.add_subplot(gs[0, 1], projection='3d')
    except IndexError:
        ax = fig.add_subplot(gs[0, 0], projection='3d')
    axis_pts = np.linspace(-5, 5, 50)
    X, Y = np.meshgrid(axis_pts, axis_pts)
    # Radial ripples plus a gentle slope along X — a toy reward landscape.
    Z = np.sin(np.sqrt(X**2 + Y**2)) + (X * 0.1)
    ax.plot_surface(X, Y, Z, cmap='viridis', edgecolor='none', alpha=0.9)
    ax.set_title("Reward Function Landscape", fontsize=12, fontweight='bold')
    ax.set_xlabel('State X')
    ax.set_ylabel('State Y')
    ax.set_zlabel('Reward R(s)')
def plot_trajectory(ax):
    """MDP & Environment: one episode drawn as an s -> (a, r) -> s' chain."""
    ax.set_title("Trajectory Sequence", fontsize=12, fontweight='bold')
    states = ['s0', 's1', 's2', 's3', 'sT']
    actions = ['a0', 'a1', 'a2', 'a3']
    rewards = ['r1', 'r2', 'r3', 'r4']
    for idx, state in enumerate(states):
        ax.text(idx, 0.5, state, ha='center', va='center', bbox=dict(boxstyle="circle", fc="white"))
        if idx < len(actions):
            # Arrow to the next state; action labelled above, reward below.
            ax.annotate("", xy=(idx+0.8, 0.5), xytext=(idx+0.2, 0.5), arrowprops=dict(arrowstyle="->"))
            ax.text(idx+0.5, 0.6, actions[idx], ha='center', color='blue')
            ax.text(idx+0.5, 0.4, rewards[idx], ha='center', color='red')
    ax.set_xlim(-0.5, len(states)-0.5)
    ax.set_ylim(0, 1)
    ax.axis('off')
def plot_continuous_space(ax):
    """MDP & Environment: scatter view of a continuous 2D state space."""
    np.random.seed(42)  # fixed seed keeps the figure reproducible
    samples = np.random.randn(200, 2)
    # Points outside the unit circle are tagged as the high-reward region.
    far = np.linalg.norm(samples, axis=1) > 1.0
    ax.scatter(samples[far, 0], samples[far, 1], c='coral', alpha=0.6, label='High Reward')
    ax.scatter(samples[~far, 0], samples[~far, 1], c='skyblue', alpha=0.6, label='Low Reward')
    ax.set_title("Continuous State Space (2D Projection)", fontsize=12, fontweight='bold')
    ax.legend(fontsize=8)
def plot_discount_decay(ax):
    """MDP & Environment: how gamma^t shrinks the weight of future rewards."""
    steps = np.arange(0, 20)
    for gamma in (0.5, 0.9, 0.99):
        # One decay curve per discount factor.
        ax.plot(steps, gamma**steps, marker='o', markersize=4, label=rf"$\gamma={gamma}$")
    ax.set_title(r"Discount Factor $\gamma^t$ Decay", fontsize=12, fontweight='bold')
    ax.set_xlabel("Time steps (t)")
    ax.set_ylabel("Weight")
    ax.legend()
    ax.grid(True, alpha=0.3)
def plot_value_heatmap(ax):
    """Value & Policy: State-Value Function V(s) Heatmap (Gridworld).

    Values fall off quadratically with distance from the goal cell at
    index (grid_size-1, grid_size-1). Note: matshow renders row 0 at the
    top, so that goal cell appears at the BOTTOM-right of the plot (the
    original comment said "top right", which was wrong).
    """
    grid_size = 5
    values = np.zeros((grid_size, grid_size))
    for i in range(grid_size):
        for j in range(grid_size):
            values[i, j] = -( (grid_size-1-i)**2 + (grid_size-1-j)**2 ) * 0.5
    values[-1, -1] = 10.0  # Goal state
    # Return value (the AxesImage) is unused, so it is not bound.
    ax.matshow(values, cmap='magma')
    for (i, j), z in np.ndenumerate(values):
        # White text on the darkest (most negative) cells for contrast.
        ax.text(j, i, f'{z:0.1f}', ha='center', va='center', color='white' if z < -5 else 'black', fontsize=9)
    ax.set_title("State-Value Function V(s) Heatmap", fontsize=12, fontweight='bold', pad=15)
    ax.set_xticks(range(grid_size))
    ax.set_yticks(range(grid_size))
def plot_backup_diagram(ax):
    """Dynamic Programming: policy-evaluation backup diagram (s -> a -> s')."""
    tree = nx.DiGraph()
    tree.add_node("s", layer=0)
    tree.add_node("a1", layer=1)
    tree.add_node("a2", layer=1)
    tree.add_node("s'_1", layer=2)
    tree.add_node("s'_2", layer=2)
    tree.add_node("s'_3", layer=2)
    tree.add_edges_from([("s", "a1"), ("s", "a2")])
    tree.add_edges_from([("a1", "s'_1"), ("a1", "s'_2"), ("a2", "s'_3")])
    # Root state on top, actions in the middle, successor states at the bottom.
    coords = {
        "s": (0.5, 1),
        "a1": (0.25, 0.5), "a2": (0.75, 0.5),
        "s'_1": (0.1, 0), "s'_2": (0.4, 0), "s'_3": (0.75, 0)
    }
    nx.draw_networkx_nodes(ax=ax, G=tree, pos=coords, nodelist=["s", "s'_1", "s'_2", "s'_3"], node_size=800, node_color='white', edgecolors='black')
    # Action nodes rendered as small solid black dots, per convention.
    nx.draw_networkx_nodes(ax=ax, G=tree, pos=coords, nodelist=["a1", "a2"], node_size=300, node_color='black')
    nx.draw_networkx_edges(ax=ax, G=tree, pos=coords, arrows=True)
    nx.draw_networkx_labels(ax=ax, G=tree, pos=coords, labels={"s": "s", "s'_1": "s'", "s'_2": "s'", "s'_3": "s'"}, font_size=10)
    ax.set_title("DP Policy Eval Backup", fontsize=12, fontweight='bold')
    ax.set_ylim(-0.2, 1.2)
    ax.axis('off')
def plot_action_value_q(ax):
    """Value & Policy: Q(s, a_up) values as an annotated 3x3 heatmap."""
    q_slice = np.random.rand(3, 3)  # random slice of Q for a single action
    ax.imshow(q_slice, cmap='YlGnBu')
    for (row, col), val in np.ndenumerate(q_slice):
        ax.text(col, row, f'{val:0.1f}', ha='center', va='center', fontsize=8)
    ax.set_title(r"Action-Value $Q(s, a_{up})$", fontsize=12, fontweight='bold')
    ax.set_xticks([])
    ax.set_yticks([])
def plot_policy_arrows(ax):
    """Value & Policy: a random policy drawn as arrows on a 4x4 grid."""
    grid_size = 4
    ax.set_xlim(-0.5, grid_size-0.5)
    ax.set_ylim(-0.5, grid_size-0.5)
    for row in range(grid_size):
        for col in range(grid_size):
            # Two independent draws, evaluated left-to-right like the tuple form.
            dx = np.random.choice([0, 0.3, -0.3])
            dy = np.random.choice([0, 0.3, -0.3])
            if dx == 0 and dy == 0:
                dx = 0.3  # never draw a zero-length arrow
            ax.add_patch(FancyArrowPatch((col, row), (col+dx, row+dy), arrowstyle='->', mutation_scale=15))
    ax.set_title(r"Policy $\pi(s)$ Arrows", fontsize=12, fontweight='bold')
    ax.set_xticks(range(grid_size))
    ax.set_yticks(range(grid_size))
    ax.grid(True, alpha=0.2)
def plot_advantage_function(ax):
    """Value & Policy: advantage A(s,a) = Q(s,a) - V(s) as signed bars."""
    actions = ['A1', 'A2', 'A3', 'A4']
    advantage = [2.1, -1.2, 0.5, -0.8]
    # Positive advantage in green, negative in red.
    bar_colors = ['green' if a > 0 else 'red' for a in advantage]
    ax.bar(actions, advantage, color=bar_colors, alpha=0.7)
    ax.axhline(0, color='black', lw=1)
    ax.set_title(r"Advantage $A(s, a)$", fontsize=12, fontweight='bold')
    ax.set_ylabel("Value")
def plot_policy_improvement(ax):
    """Dynamic Programming: greedy improvement mapping pi_old to pi_new."""
    ax.axis('off')
    ax.set_title("Policy Improvement", fontsize=12, fontweight='bold')
    # Old policy on the left, improved policy on the right, arrow between.
    ax.text(0.2, 0.5, r"$\pi_{old}$", fontsize=15, bbox=dict(boxstyle="round", fc="lightgrey"))
    ax.annotate("", xy=(0.8, 0.5), xytext=(0.3, 0.5), arrowprops=dict(arrowstyle="->", lw=2))
    ax.text(0.5, 0.6, "Greedy\nImprovement", ha='center', fontsize=9)
    ax.text(0.85, 0.5, r"$\pi_{new}$", fontsize=15, bbox=dict(boxstyle="round", fc="lightgreen"))
def plot_value_iteration_backup(ax):
    """Dynamic Programming: value-iteration backup (max over actions)."""
    tree = nx.DiGraph()
    # Root state on top, the max operator in the middle, successors below.
    coords = {"s": (0.5, 1), "max": (0.5, 0.5), "s1": (0.2, 0), "s2": (0.5, 0), "s3": (0.8, 0)}
    tree.add_nodes_from(coords.keys())
    tree.add_edges_from([("s", "max"), ("max", "s1"), ("max", "s2"), ("max", "s3")])
    nx.draw_networkx_nodes(ax=ax, G=tree, pos=coords, node_size=500, node_color='white', edgecolors='black')
    nx.draw_networkx_edges(ax=ax, G=tree, pos=coords, arrows=True)
    nx.draw_networkx_labels(ax=ax, G=tree, pos=coords, labels={"s": "s", "max": "max", "s1": "s'", "s2": "s'", "s3": "s'"}, font_size=9)
    ax.set_title("Value Iteration Backup", fontsize=12, fontweight='bold')
    ax.axis('off')
def plot_policy_iteration_cycle(ax):
    """Dynamic Programming: the evaluate <-> improve cycle of policy iteration."""
    ax.axis('off')
    ax.set_title("Policy Iteration Cycle", fontsize=12, fontweight='bold')
    box_style = dict(boxstyle="round", fc="aliceblue", ec="black")
    ax.text(0.5, 0.8, r"Policy Evaluation" + "\n" + r"$V \leftarrow V^\pi$", ha="center", bbox=box_style)
    ax.text(0.5, 0.2, r"Policy Improvement" + "\n" + r"$\pi \leftarrow \text{greedy}(V)$", ha="center", bbox=box_style)
    # Two opposing arcs form the closed loop between the boxes.
    ax.annotate("", xy=(0.7, 0.3), xytext=(0.7, 0.7), arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5"))
    ax.annotate("", xy=(0.3, 0.7), xytext=(0.3, 0.3), arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5"))
def plot_mc_backup(ax):
    """Monte Carlo: backup spanning the full episode down to terminal sT."""
    ax.axis('off')
    ax.set_title("Monte Carlo Backup", fontsize=12, fontweight='bold')
    nodes = ['s', 's1', 's2', 'sT']
    # Stack the states vertically, 0.25 apart starting at y = 0.9.
    pos = {name: (0.5, 0.9 - k*0.25) for k, name in enumerate(nodes)}
    for k in range(len(nodes)-1):
        ax.annotate("", xy=pos[nodes[k+1]], xytext=pos[nodes[k]], arrowprops=dict(arrowstyle="->", lw=1.5))
        ax.text(pos[nodes[k]][0]+0.05, pos[nodes[k]][1], nodes[k], va='center')
    ax.text(pos['sT'][0]+0.05, pos['sT'][1], 'sT', va='center', fontweight='bold')
    # Red arc: the return G from the whole episode updates V(s) at the root.
    ax.annotate("Update V(s) using G", xy=(0.3, 0.9), xytext=(0.3, 0.15), arrowprops=dict(arrowstyle="->", color='red', connectionstyle="arc3,rad=0.3"))
def plot_mcts(ax):
    """Monte Carlo: Monte Carlo Tree Search (MCTS) tree diagram.

    Draws a depth-2 binary tree with a hand-laid layout. The original code
    first computed a layout guarded by `'pygraphviz' in globals()` — a test
    that is always False here (pygraphviz is never imported into this
    module), and whose result was immediately overwritten by the hard-coded
    positions below. That dead computation has been removed.
    """
    G = nx.balanced_tree(2, 2, create_using=nx.DiGraph())
    # Manual layout: root at the top, two levels fanning out beneath it.
    pos = {0:(0,0), 1:(-1,-1), 2:(1,-1), 3:(-1.5,-2), 4:(-0.5,-2), 5:(0.5,-2), 6:(1.5,-2)}
    nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, node_size=300, node_color='lightyellow', edgecolors='black')
    nx.draw_networkx_edges(ax=ax, G=G, pos=pos, arrows=True)
    ax.set_title("MCTS Tree", fontsize=12, fontweight='bold')
    ax.axis('off')
def plot_importance_sampling(ax):
    """Monte Carlo: target policy, behavior policy, and their ratio rho."""
    ax.axis('off')
    ax.set_title("Importance Sampling", fontsize=12, fontweight='bold')
    # Target policy on top, behavior policy below.
    ax.text(0.5, 0.8, r"$\pi(a|s)$", bbox=dict(boxstyle="circle", fc="lightgreen"), ha='center')
    ax.text(0.5, 0.2, r"$b(a|s)$", bbox=dict(boxstyle="circle", fc="lightpink"), ha='center')
    # The correction ratio applied to off-policy returns.
    ax.annotate(r"$\rho = \frac{\pi}{b}$", xy=(0.7, 0.5), fontsize=15)
    ax.annotate("", xy=(0.5, 0.35), xytext=(0.5, 0.65), arrowprops=dict(arrowstyle="<->", lw=2))
def plot_td_backup(ax):
    """Temporal Difference: one-step TD(0) backup from s' to s."""
    ax.axis('off')
    ax.set_title("TD(0) Backup", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.8, "s", bbox=dict(boxstyle="circle", fc="white"), ha='center')
    ax.text(0.5, 0.2, "s'", bbox=dict(boxstyle="circle", fc="white"), ha='center')
    # The bootstrapped target feeding the update.
    ax.annotate(r"$R + \gamma V(s')$", xy=(0.5, 0.4), ha='center', color='blue')
    # Arrowhead points up toward s: information flows back from s'.
    ax.annotate("", xy=(0.5, 0.35), xytext=(0.5, 0.65), arrowprops=dict(arrowstyle="<-", lw=2))
def plot_nstep_td(ax):
    """Temporal Difference: n-step backup across a short chain of states."""
    ax.axis('off')
    ax.set_title("n-step TD Backup", fontsize=12, fontweight='bold')
    for step in range(4):
        y = 0.9 - step*0.2
        ax.text(0.5, y, f"s_{step}", bbox=dict(boxstyle="circle", fc="white"), ha='center', fontsize=8)
        if step < 3:
            # Connector down to the next state in the chain.
            ax.annotate("", xy=(0.5, y - 0.15), xytext=(0.5, y - 0.05), arrowprops=dict(arrowstyle="->"))
    ax.annotate(r"$G_t^{(n)}$", xy=(0.7, 0.5), fontsize=12, color='red')
def plot_eligibility_traces(ax):
    """Temporal Difference: TD(lambda) eligibility trace with repeat visits."""
    timeline = np.arange(0, 50)
    trace = np.zeros_like(timeline, dtype=float)
    # Each visit adds a fresh geometric decay (0.8^k) on top of the trace.
    for visit in [5, 20, 35]:
        trace[visit:] += (0.8 ** np.arange(len(timeline)-visit))
    ax.plot(timeline, trace, color='brown', lw=2)
    ax.set_title(r"Eligibility Trace $z_t(\lambda)$", fontsize=12, fontweight='bold')
    ax.set_xlabel("Time")
    ax.fill_between(timeline, trace, color='brown', alpha=0.1)
def plot_sarsa_backup(ax):
    """Temporal Difference: one-step SARSA backup (on-policy)."""
    ax.axis('off')
    ax.set_title("SARSA Backup", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.9, "(s,a)", ha='center')
    ax.text(0.5, 0.1, "(s',a')", ha='center')
    # Arrowhead toward (s,a): the update flows back from the sampled (s',a').
    ax.annotate("", xy=(0.5, 0.2), xytext=(0.5, 0.8), arrowprops=dict(arrowstyle="<-", lw=2, color='orange'))
    ax.text(0.6, 0.5, "On-policy", rotation=90)
def plot_q_learning_backup(ax):
    """Temporal Difference: Q-learning backup using the max successor value."""
    ax.axis('off')
    ax.set_title("Q-Learning Backup", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.9, "(s,a)", ha='center')
    # Off-policy target: the greedy max over next-state actions.
    ax.text(0.5, 0.1, r"$\max_{a'} Q(s',a')$", ha='center', bbox=dict(boxstyle="round", fc="lightcyan"))
    ax.annotate("", xy=(0.5, 0.25), xytext=(0.5, 0.8), arrowprops=dict(arrowstyle="<-", lw=2, color='blue'))
def plot_double_q(ax):
    """Temporal Difference: Double Q-Learning's split select/evaluate roles."""
    ax.axis('off')
    ax.set_title("Double Q-Learning", fontsize=12, fontweight='bold')
    # Network A picks the action, network B scores it — decoupling the max bias.
    ax.text(0.5, 0.8, "Network A", bbox=dict(fc="lightyellow"), ha='center')
    ax.text(0.5, 0.2, "Network B", bbox=dict(fc="lightcyan"), ha='center')
    ax.annotate("Select $a^*$", xy=(0.3, 0.8), xytext=(0.5, 0.85), arrowprops=dict(arrowstyle="->"))
    ax.annotate("Eval $Q(s', a^*)$", xy=(0.7, 0.2), xytext=(0.5, 0.15), arrowprops=dict(arrowstyle="->"))
def plot_dueling_dqn(ax):
    """Temporal Difference: dueling architecture combining V(s) and A(s,a)."""
    ax.axis('off')
    ax.set_title("Dueling DQN", fontsize=12, fontweight='bold')
    # Shared backbone feeds two heads which merge into Q(s,a).
    ax.text(0.1, 0.5, "Backbone", bbox=dict(fc="lightgrey"), ha='center', rotation=90)
    ax.text(0.5, 0.7, "V(s)", bbox=dict(fc="lightgreen"), ha='center')
    ax.text(0.5, 0.3, "A(s,a)", bbox=dict(fc="lightblue"), ha='center')
    ax.text(0.9, 0.5, "Q(s,a)", bbox=dict(boxstyle="circle", fc="orange"), ha='center')
    # Backbone -> heads.
    ax.annotate("", xy=(0.35, 0.7), xytext=(0.15, 0.55), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.35, 0.3), xytext=(0.15, 0.45), arrowprops=dict(arrowstyle="->"))
    # Heads -> combined Q output.
    ax.annotate("", xy=(0.75, 0.55), xytext=(0.6, 0.7), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.75, 0.45), xytext=(0.6, 0.3), arrowprops=dict(arrowstyle="->"))
def plot_prioritized_replay(ax):
    """Temporal Difference: heavy-tailed distribution of replay priorities."""
    # Pareto draws mimic a few high-TD-error transitions dominating the buffer.
    priorities = np.random.pareto(3, 100)
    ax.hist(priorities, bins=20, color='teal', alpha=0.7)
    ax.set_title("Prioritized Replay (TD-Error)", fontsize=12, fontweight='bold')
    ax.set_xlabel("Priority $P_i$")
    ax.set_ylabel("Count")
def plot_rainbow_dqn(ax):
    """Temporal Difference: the component stack that makes up Rainbow DQN."""
    ax.axis('off')
    ax.set_title("Rainbow DQN", fontsize=12, fontweight='bold')
    components = ["Double", "Dueling", "PER", "Noisy", "Distributional", "n-step"]
    for slot, component in enumerate(components):
        # Stack the component labels top-to-bottom.
        ax.text(0.5, 0.9 - slot*0.15, component, ha='center', bbox=dict(boxstyle="round", fc="ghostwhite"), fontsize=8)
def plot_linear_fa(ax):
    """Function Approximation: features phi(s) mapped through a linear weight."""
    ax.axis('off')
    ax.set_title("Linear Function Approx", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.8, r"$\phi(s)$ Features", ha='center', bbox=dict(fc="white"))
    ax.text(0.5, 0.2, r"$w^T \phi(s)$", ha='center', bbox=dict(fc="lightgrey"))
    ax.annotate("", xy=(0.5, 0.35), xytext=(0.5, 0.65), arrowprops=dict(arrowstyle="->", lw=2))
def plot_nn_layers(ax):
    """Function Approximation: dot diagram of an MLP's layers."""
    ax.axis('off')
    ax.set_title("NN Layers (Deep RL)", fontsize=12, fontweight='bold')
    layer_widths = [4, 8, 8, 2]
    for col, width in enumerate(layer_widths):
        for row in range(width):
            # Offset by width*0.05 so each column is vertically centred at 0.
            ax.scatter(col*0.3, row*0.1 - width*0.05, s=20, c='black')
    ax.set_xlim(-0.1, 1.0)
    ax.set_ylim(-0.5, 0.5)
def plot_computation_graph(ax):
    """Function Approximation: forward DAG with the backward gradient arc."""
    ax.axis('off')
    ax.set_title("Computation Graph (DAG)", fontsize=12, fontweight='bold')
    ax.text(0.1, 0.5, "Input", bbox=dict(boxstyle="circle", fc="white"))
    ax.text(0.5, 0.5, "Op", bbox=dict(boxstyle="square", fc="lightgrey"))
    ax.text(0.9, 0.5, "Loss", bbox=dict(boxstyle="circle", fc="salmon"))
    # Forward pass, left to right.
    ax.annotate("", xy=(0.35, 0.5), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.75, 0.5), xytext=(0.6, 0.5), arrowprops=dict(arrowstyle="->"))
    # Backward pass: gradient flows from the loss back to the input.
    ax.annotate("Grad", xy=(0.1, 0.3), xytext=(0.9, 0.3), arrowprops=dict(arrowstyle="->", color='red', connectionstyle="arc3,rad=0.2"))
def plot_target_network(ax):
    """Function Approximation: active network periodically copied to a frozen target."""
    ax.axis('off')
    ax.set_title("Target Network Updates", fontsize=12, fontweight='bold')
    ax.text(0.3, 0.8, r"$Q_\theta$ (Active)", bbox=dict(fc="lightgreen"))
    ax.text(0.7, 0.8, r"$Q_{\theta^-}$ (Target)", bbox=dict(fc="lightblue"))
    # Dashed arrow: weights copied into the target at fixed intervals.
    ax.annotate("periodic copy", xy=(0.6, 0.8), xytext=(0.4, 0.8), arrowprops=dict(arrowstyle="<-", ls='--'))
def plot_ppo_clip(ax):
    """Policy Gradients: PPO clipped surrogate objective versus the ratio r."""
    epsilon = 0.2
    ratio = np.linspace(0.5, 1.5, 100)
    advantage = 1.0
    unclipped = ratio * advantage
    # Ratio clipped into [1 - eps, 1 + eps] before scaling by the advantage.
    clipped = np.clip(ratio, 1-epsilon, 1+epsilon) * advantage
    ax.plot(ratio, unclipped, '--', label="r*A")
    ax.plot(ratio, np.minimum(unclipped, clipped), 'r', label="min(r*A, clip*A)")
    ax.set_title("PPO-Clip Objective", fontsize=12, fontweight='bold')
    ax.legend(fontsize=8)
    ax.axvline(1, color='gray', linestyle=':')
def plot_trpo_trust_region(ax):
    """Policy Gradients: a policy update constrained inside a KL ball."""
    ax.set_title("TRPO Trust Region", fontsize=12, fontweight='bold')
    # The KL constraint drawn as a circle around the current policy.
    boundary = plt.Circle((0.5, 0.5), 0.3, color='blue', fill=False, label="KL Constraint")
    ax.add_artist(boundary)
    ax.scatter(0.5, 0.5, c='black', label=r"$\pi_{old}$")
    # The proposed step stays within the boundary.
    ax.arrow(0.5, 0.5, 0.15, 0.1, head_width=0.03, color='red', label="Update")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis('off')
def plot_a3c_multi_worker(ax):
    """Actor-Critic: A3C's global parameters synced with parallel workers."""
    ax.axis('off')
    ax.set_title("A3C Multi-worker", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.8, "Global Parameters", bbox=dict(fc="gold"), ha='center')
    for worker in range(3):
        # Each worker gets a box plus a two-way sync arrow to the global store.
        ax.text(0.2 + worker*0.3, 0.2, f"Worker {worker+1}", bbox=dict(fc="lightgrey"), ha='center', fontsize=8)
        ax.annotate("", xy=(0.5, 0.7), xytext=(0.2 + worker*0.3, 0.3), arrowprops=dict(arrowstyle="<->"))
def plot_sac_arch(ax):
    """Actor-Critic: SAC actor with its entropy bonus feeding the action."""
    ax.axis('off')
    ax.set_title("SAC Architecture", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.7, "Actor", bbox=dict(fc="lightgreen"), ha='center')
    ax.text(0.5, 0.3, "Entropy Bonus", bbox=dict(fc="salmon"), ha='center')
    ax.text(0.1, 0.5, "State", ha='center')
    ax.text(0.9, 0.5, "Action", ha='center')
    # State -> actor, entropy -> actor, actor -> action.
    ax.annotate("", xy=(0.4, 0.7), xytext=(0.15, 0.5), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.5, 0.55), xytext=(0.5, 0.4), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.85, 0.5), xytext=(0.6, 0.7), arrowprops=dict(arrowstyle="->"))
def plot_softmax_exploration(ax):
    """Exploration: Boltzmann action probabilities at several temperatures."""
    idx = np.arange(4)
    logits = [1, 2, 5, 3]
    for tau in [0.5, 1.0, 5.0]:
        # Softmax over the fixed logits at temperature tau.
        weights = np.exp(np.array(logits)/tau)
        ax.plot(idx, weights / weights.sum(), marker='o', label=rf"$\tau={tau}$")
    ax.set_title("Softmax Exploration", fontsize=12, fontweight='bold')
    ax.legend(fontsize=8)
    ax.set_xticks(idx)
def plot_ucb_confidence(ax):
    """Exploration: mean action values with UCB-style confidence bars."""
    actions = ['A1', 'A2', 'A3']
    mean_q = [0.6, 0.8, 0.5]
    # Wider error bars model greater uncertainty (fewer pulls).
    uncertainty = [0.3, 0.1, 0.4]
    ax.bar(actions, mean_q, yerr=uncertainty, capsize=10, color='skyblue', label='Mean Q')
    ax.set_title("UCB Action Values", fontsize=12, fontweight='bold')
    ax.set_ylim(0, 1.2)
def plot_intrinsic_motivation(ax):
    """Exploration: curiosity reward derived from world-model prediction error."""
    ax.axis('off')
    ax.set_title("Intrinsic Motivation", fontsize=12, fontweight='bold')
    ax.text(0.3, 0.5, "World Model", bbox=dict(fc="lightyellow"), ha='center')
    ax.text(0.7, 0.5, "Prediction\nError", bbox=dict(boxstyle="circle", fc="orange"), ha='center')
    ax.annotate("", xy=(0.58, 0.5), xytext=(0.42, 0.5), arrowprops=dict(arrowstyle="->"))
    # The error becomes the intrinsic reward signal.
    ax.text(0.85, 0.5, r"$R_{int}$", fontweight='bold')
def plot_entropy_bonus(ax):
    """Exploration: Bernoulli entropy as a function of P(a)."""
    prob = np.linspace(0.01, 0.99, 50)
    # H(p) = -p log p - (1-p) log(1-p), maximal at p = 0.5.
    ent = -(prob * np.log(prob) + (1-prob) * np.log(1-prob))
    ax.plot(prob, ent, color='purple')
    ax.set_title(r"Entropy $H(\pi)$", fontsize=12, fontweight='bold')
    ax.set_xlabel("$P(a)$")
def plot_options_framework(ax):
    """Hierarchical RL: high-level policy delegating to temporally-extended options."""
    ax.axis('off')
    ax.set_title("Options Framework", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.8, r"High-level policy" + "\n" + r"$\pi_{hi}$", bbox=dict(fc="lightblue"), ha='center')
    ax.text(0.2, 0.2, "Option 1", bbox=dict(fc="ivory"), ha='center')
    ax.text(0.8, 0.2, "Option 2", bbox=dict(fc="ivory"), ha='center')
    # Delegation arrows from the high-level policy to each option.
    ax.annotate("", xy=(0.3, 0.3), xytext=(0.45, 0.7), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.7, 0.3), xytext=(0.55, 0.7), arrowprops=dict(arrowstyle="->"))
def plot_feudal_networks(ax):
    """Hierarchical RL: manager issues goals that the worker pursues."""
    ax.axis('off')
    ax.set_title("Feudal Networks", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.85, "Manager", bbox=dict(fc="plum"), ha='center')
    ax.text(0.5, 0.15, "Worker", bbox=dict(fc="wheat"), ha='center')
    ax.annotate("Goal $g_t$", xy=(0.5, 0.3), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->", lw=2))
def plot_world_model(ax):
    """Model-Based RL: learned dynamics mapping (s,a) to predicted s' and r."""
    ax.axis('off')
    ax.set_title("World Model (Dynamics)", fontsize=12, fontweight='bold')
    ax.text(0.1, 0.5, "(s,a)", ha='center')
    ax.text(0.5, 0.5, r"$\hat{P}$", bbox=dict(boxstyle="circle", fc="lightgrey"), ha='center')
    # Two predicted outputs: the next state (top) and the reward (bottom).
    ax.text(0.9, 0.7, r"$\hat{s}'$", ha='center')
    ax.text(0.9, 0.3, r"$\hat{r}$", ha='center')
    ax.annotate("", xy=(0.4, 0.5), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.8, 0.65), xytext=(0.6, 0.55), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.8, 0.35), xytext=(0.6, 0.45), arrowprops=dict(arrowstyle="->"))
def plot_model_planning(ax):
    """Model-Based RL: rollouts imagined from the real state."""
    ax.axis('off')
    ax.set_title("Model-Based Planning", fontsize=12, fontweight='bold')
    ax.text(0.1, 0.5, "Real s", ha='center', fontweight='bold')
    for step in range(3):
        # Alternate the arrow height with step parity to suggest branching.
        ax.annotate("", xy=(0.3+step*0.2, 0.5+(step%2)*0.1), xytext=(0.1+step*0.2, 0.5), arrowprops=dict(arrowstyle="->", color='gray'))
        ax.text(0.3+step*0.2, 0.55+(step%2)*0.1, "imagined", fontsize=7)
def plot_offline_rl(ax):
    """Offline RL: learning from a fixed dataset, no environment interaction."""
    ax.axis('off')
    ax.set_title("Offline RL Dataset", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.5, r"Static" + "\n" + r"Dataset" + "\n" + r"$\mathcal{D}$", bbox=dict(boxstyle="round", fc="lightgrey"), ha='center')
    ax.annotate("No interaction", xy=(0.5, 0.9), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->", color='red'))
    # X markers stand in for logged transitions scattered around the dataset.
    ax.scatter([0.2, 0.8, 0.3, 0.7], [0.8, 0.8, 0.2, 0.2], marker='x', color='blue')
def plot_cql_regularization(ax):
    """Offline RL: quadratic CQL-style penalty as a function of Q."""
    q_vals = np.linspace(-5, 5, 100)
    # Penalty grows quadratically away from zero.
    penalty_curve = q_vals**2 * 0.1
    ax.plot(q_vals, penalty_curve, 'r', label='CQL Penalty')
    ax.set_title("CQL Regularization", fontsize=12, fontweight='bold')
    ax.set_xlabel("Q-value")
    ax.legend(fontsize=8)
def plot_multi_agent_interaction(ax):
    """Multi-Agent RL: three fully-connected agents (dashed = communication)."""
    graph = nx.complete_graph(3)
    layout = nx.spring_layout(graph)  # unseeded: layout varies per call
    nx.draw_networkx_nodes(ax=ax, G=graph, pos=layout, node_size=500, node_color=['red', 'blue', 'green'])
    nx.draw_networkx_edges(ax=ax, G=graph, pos=layout, style='dashed')
    ax.set_title("Multi-Agent Interaction", fontsize=12, fontweight='bold')
    ax.axis('off')
def plot_ctde(ax):
    """Multi-Agent RL: centralized critic feeding decentralized agents."""
    ax.axis('off')
    ax.set_title("CTDE Architecture", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.8, "Centralized Critic", bbox=dict(fc="gold"), ha='center')
    ax.text(0.2, 0.2, "Agent 1", bbox=dict(fc="lightblue"), ha='center')
    ax.text(0.8, 0.2, "Agent 2", bbox=dict(fc="lightblue"), ha='center')
    # Arrowheads at the agents: critic signal flows down during training.
    ax.annotate("", xy=(0.5, 0.7), xytext=(0.25, 0.35), arrowprops=dict(arrowstyle="<-", color='gray'))
    ax.annotate("", xy=(0.5, 0.7), xytext=(0.75, 0.35), arrowprops=dict(arrowstyle="<-", color='gray'))
def plot_payoff_matrix(ax):
    """Multi-Agent RL: a prisoner's-dilemma style 2x2 payoff matrix."""
    # Shape (2, 2, 2): each cell holds the pair of player payoffs.
    payoffs = np.array([[(3,3), (0,5)], [(5,0), (1,1)]])
    ax.axis('off')
    ax.set_title("Payoff Matrix (Prisoner's)", fontsize=12, fontweight='bold')
    for row in range(2):
        for col in range(2):
            # str() of the length-2 payoff vector renders e.g. "[3 3]".
            ax.text(col, 1-row, str(payoffs[row, col]), ha='center', va='center', bbox=dict(fc="white"))
    ax.set_xlim(-0.5, 1.5)
    ax.set_ylim(-0.5, 1.5)
def plot_irl_reward_inference(ax):
    """Inverse RL: reward heatmap inferred from expert demonstrations."""
    ax.axis('off')
    ax.set_title("Inferred Reward Heatmap", fontsize=12, fontweight='bold')
    inferred = np.zeros((5, 5))
    # Hot patch marks the cells the expert path visits.
    inferred[2:4, 2:4] = 1.0
    ax.imshow(inferred, cmap='hot')
def plot_gail_flow(ax):
    """Inverse RL: GAIL — discriminator judging expert vs. generated data."""
    ax.axis('off')
    ax.set_title("GAIL Architecture", fontsize=12, fontweight='bold')
    ax.text(0.2, 0.8, "Expert Data", bbox=dict(fc="lightgrey"), ha='center')
    ax.text(0.2, 0.2, "Policy (Gen)", bbox=dict(fc="lightgreen"), ha='center')
    ax.text(0.8, 0.5, "Discriminator", bbox=dict(boxstyle="square", fc="salmon"), ha='center')
    # Both data streams feed the discriminator.
    ax.annotate("", xy=(0.6, 0.55), xytext=(0.35, 0.75), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.6, 0.45), xytext=(0.35, 0.25), arrowprops=dict(arrowstyle="->"))
def plot_meta_rl_nested_loop(ax):
    """Meta-RL: outer meta-learning loop wrapping the inner adaptation loop."""
    ax.axis('off')
    ax.set_title("Meta-RL Loops", fontsize=12, fontweight='bold')
    # Dashed outer ring (meta loop) around the solid inner ring (adaptation).
    ax.add_patch(plt.Circle((0.5, 0.5), 0.4, fill=False, ls='--'))
    ax.add_patch(plt.Circle((0.5, 0.5), 0.2, fill=False))
    ax.text(0.5, 0.5, "Inner\nLoop", ha='center', fontsize=8)
    ax.text(0.5, 0.8, "Outer Loop", ha='center', fontsize=10)
def plot_task_distribution(ax):
    """Meta-RL: individual MDP tasks sampled from a distribution."""
    ax.axis('off')
    ax.set_title("Task Distribution", fontsize=12, fontweight='bold')
    for k in range(3):
        ax.text(0.2 + k*0.3, 0.5, f"Task {k+1}", bbox=dict(boxstyle="round", fc="ivory"), fontsize=8)
    # Arrow indicating tasks are drawn from the distribution above.
    ax.annotate("sample", xy=(0.5, 0.8), xytext=(0.5, 0.6), arrowprops=dict(arrowstyle="<-"))
def plot_replay_buffer(ax):
    """Advanced: FIFO experience replay buffer with in/out arrows."""
    ax.axis('off')
    ax.set_title("Experience Replay Buffer", fontsize=12, fontweight='bold')
    for slot in range(5):
        # One rectangle per stored experience, labelled e_0 .. e_4.
        ax.add_patch(plt.Rectangle((0.1+slot*0.15, 0.4), 0.1, 0.2, fill=True, color='lightgrey'))
        ax.text(0.15+slot*0.15, 0.5, f"e_{slot}", ha='center')
    # Entry arrow on the left, batch-sampling arrow on the right.
    ax.annotate("In", xy=(0.05, 0.5), xytext=(-0.1, 0.5), arrowprops=dict(arrowstyle="->"), annotation_clip=False)
    ax.annotate("Out (Batch)", xy=(0.85, 0.5), xytext=(1.0, 0.5), arrowprops=dict(arrowstyle="<-"), annotation_clip=False)
def plot_state_visitation(ax):
    """Advanced: state-occupancy density rendered as a hexbin heatmap."""
    # Correlated 2D Gaussian samples stand in for visited states (unseeded).
    visits = np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], 1000)
    ax.hexbin(visits[:, 0], visits[:, 1], gridsize=15, cmap='Blues')
    ax.set_title("State Visitation Heatmap", fontsize=12, fontweight='bold')
def plot_regret_curve(ax):
    """Advanced: cumulative regret growing sub-linearly (~sqrt(t)) with noise."""
    horizon = np.arange(100)
    noisy_regret = np.sqrt(horizon) + np.random.normal(0, 0.5, 100)
    ax.plot(horizon, noisy_regret, color='red', label='Sub-linear Regret')
    ax.set_title("Cumulative Regret", fontsize=12, fontweight='bold')
    ax.set_xlabel("Time")
    ax.legend(fontsize=8)
def plot_attention_weights(ax):
    """Advanced: a random 5x5 attention-weight matrix (unseeded)."""
    attn = np.random.rand(5, 5)
    ax.imshow(attn, cmap='viridis')
    ax.set_title("Attention Weight Matrix", fontsize=12, fontweight='bold')
    ax.set_xticks([])
    ax.set_yticks([])
def plot_diffusion_policy(ax):
    """Advanced: denoising chain of a diffusion policy, noise to action."""
    ax.axis('off')
    ax.set_title("Diffusion Policy (Denoising)", fontsize=12, fontweight='bold')
    for step in range(4):
        # Dots shrink and fade as the sample is progressively denoised.
        ax.scatter(0.1+step*0.25, 0.5, s=100/(step+1), c='black', alpha=1.0 - step*0.2)
        if step < 3:
            ax.annotate("", xy=(0.25+step*0.25, 0.5), xytext=(0.15+step*0.25, 0.5), arrowprops=dict(arrowstyle="->"))
    ax.text(0.5, 0.3, "Noise $\\rightarrow$ Action", ha='center', fontsize=8)
def plot_gnn_rl(ax):
    """Advanced: star graph standing in for GNN message passing."""
    graph = nx.star_graph(4)
    layout = nx.spring_layout(graph)  # unseeded: layout varies per call
    nx.draw_networkx_nodes(ax=ax, G=graph, pos=layout, node_size=200, node_color='orange')
    nx.draw_networkx_edges(ax=ax, G=graph, pos=layout)
    ax.set_title("GNN Message Passing", fontsize=12, fontweight='bold')
    ax.axis('off')
def plot_latent_space(ax):
    """Advanced: encode-decode pipeline through a latent variable z."""
    ax.axis('off')
    ax.set_title("Latent Space (VAE/Dreamer)", fontsize=12, fontweight='bold')
    ax.text(0.1, 0.5, "Image", bbox=dict(fc="lightgrey"), ha='center')
    ax.text(0.5, 0.5, "Latent $z$", bbox=dict(boxstyle="circle", fc="lightpink"), ha='center')
    ax.text(0.9, 0.5, "Reconstruction", bbox=dict(fc="lightgrey"), ha='center')
    # Encoder arrow then decoder arrow, left to right.
    ax.annotate("", xy=(0.4, 0.5), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.8, 0.5), xytext=(0.6, 0.5), arrowprops=dict(arrowstyle="->"))
def plot_convergence_log(ax):
    """Advanced: value-error convergence on a log-log scale."""
    iters = np.arange(1, 100)
    # 1/k^2 decay drawn as a straight line under loglog axes.
    err = 10 / iters**2
    ax.loglog(iters, err, color='green')
    ax.set_title("Value Convergence (Log)", fontsize=12, fontweight='bold')
    ax.set_xlabel("Iterations")
    ax.set_ylabel("Error")
    ax.grid(True, which="both", ls="-", alpha=0.3)
def plot_expected_sarsa_backup(ax):
    """Temporal Difference: Expected SARSA — expectation over policy actions."""
    ax.axis('off')
    ax.set_title("Expected SARSA Backup", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.9, "(s,a)", ha='center')
    # Target is the policy-weighted average of next-state action values.
    ax.text(0.5, 0.1, r"$\sum_{a'} \pi(a'|s') Q(s',a')$", ha='center', bbox=dict(boxstyle="round", fc="ivory"))
    ax.annotate("", xy=(0.5, 0.25), xytext=(0.5, 0.8), arrowprops=dict(arrowstyle="<-", lw=2, color='purple'))
def plot_reinforce_flow(ax):
    """Policy Gradients: REINFORCE — a full trajectory feeding the gradient."""
    ax.axis('off')
    ax.set_title("REINFORCE Flow", fontsize=12, fontweight='bold')
    chain = ["s0", "a0", "r1", "s1", "...", "GT"]
    for slot, token in enumerate(chain):
        ax.text(0.1 + slot*0.15, 0.5, token, bbox=dict(boxstyle="circle", fc="white"))
    # The score-function gradient scaled by the return.
    ax.annotate(r"$\nabla_\theta J \propto G_t \nabla \ln \pi$", xy=(0.5, 0.8), ha='center', fontsize=12, color='darkgreen')
def plot_advantage_scaled_grad(ax):
    """Policy Gradients: baseline subtraction scaling the log-prob gradient."""
    ax.axis('off')
    ax.set_title("Baseline Subtraction", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.8, r"$(G_t - b(s))$", bbox=dict(fc="salmon"), ha='center')
    ax.text(0.5, 0.3, r"Scale $\nabla \ln \pi$", ha='center')
    ax.annotate("", xy=(0.5, 0.4), xytext=(0.5, 0.7), arrowprops=dict(arrowstyle="->"))
def plot_skill_discovery(ax):
    """Hierarchical RL: unsupervised skills as clusters in embedding space."""
    np.random.seed(0)  # fixed seed keeps the clusters reproducible
    for skill in range(3):
        # Random cluster center, then a tight blob of points around it.
        center = np.random.randn(2) * 2
        blob = np.random.randn(20, 2) * 0.5 + center
        ax.scatter(blob[:, 0], blob[:, 1], alpha=0.6, label=f"Skill {skill+1}")
    ax.set_title("Skill Embedding Space", fontsize=12, fontweight='bold')
    ax.legend(fontsize=8)
def plot_imagination_rollout(ax):
    """Model-Based RL: I2A-style imagination module expanding a state."""
    ax.axis('off')
    ax.set_title("Imagination Rollout (I2A)", fontsize=12, fontweight='bold')
    ax.text(0.1, 0.5, "Input s", ha='center')
    # Central imagination box producing candidate futures.
    ax.add_patch(plt.Rectangle((0.3, 0.3), 0.4, 0.4, fill=True, color='lavender'))
    ax.text(0.5, 0.5, "Imagination\nModule", ha='center')
    ax.annotate("Imagined Paths", xy=(0.8, 0.5), xytext=(0.5, 0.5), arrowprops=dict(arrowstyle="->", color='gray', connectionstyle="arc3,rad=0.3"))
def plot_policy_gradient_flow(ax):
    """Policy Gradients: Gradient flow from reward to log-prob (DAG)."""
    ax.axis('off')
    ax.set_title("Policy Gradient Flow (DAG)", fontsize=12, fontweight='bold')
    grey_box = dict(boxstyle="round,pad=0.5", fc="lightgrey", ec="black", lw=1.5)
    # DAG nodes: trajectory and reward on top, log-prob below.
    nodes = [
        (0.1, 0.8, r"Trajectory $\tau$"),
        (0.5, 0.8, r"Reward $R(\tau)$"),
        (0.1, 0.2, r"Log-Prob $\log \pi_\theta$"),
    ]
    for x, y, label in nodes:
        ax.text(x, y, label, ha="center", va="center", bbox=grey_box)
    # Gradient node on the right, highlighted in gold.
    ax.text(0.7, 0.5, r"$\nabla_\theta J(\theta)$", ha="center", va="center",
            bbox=dict(boxstyle="circle,pad=0.3", fc="gold", ec="black"))
    # Edges of the DAG flowing into the gradient node.
    edge_style = dict(arrowstyle="->", lw=2)
    edges = [
        ((0.35, 0.8), (0.2, 0.8)),   # trajectory -> reward
        ((0.7, 0.65), (0.5, 0.75)),  # reward -> gradient
        ((0.6, 0.4), (0.25, 0.2)),   # log-prob -> gradient
    ]
    for head, tail in edges:
        ax.annotate("", xy=head, xytext=tail, arrowprops=edge_style)
def plot_actor_critic_arch(ax):
    """Actor-Critic: Three-network diagram (TD3 - actor + two critics)."""
    ax.axis('off')
    ax.set_title("TD3 Architecture Diagram", fontsize=12, fontweight='bold')
    # State input node on the left.
    ax.text(0.1, 0.5, r"State" + "\n" + r"$s$", ha="center", va="center",
            bbox=dict(boxstyle="circle,pad=0.5", fc="lightblue"))
    # Actor and the twin critics, stacked vertically in the middle.
    net_box = dict(boxstyle="square,pad=0.8", fc="lightgreen", ec="black")
    networks = [
        (0.8, r"Actor $\pi_\phi$"),
        (0.5, r"Critic 1 $Q_{\theta_1}$"),
        (0.2, r"Critic 2 $Q_{\theta_2}$"),
    ]
    for y, label in networks:
        ax.text(0.5, y, label, ha="center", va="center", bbox=net_box)
    # Output nodes: the chosen action and the min of the two Q estimates.
    ax.text(0.8, 0.8, "Action $a$", ha="center", va="center",
            bbox=dict(boxstyle="circle,pad=0.3", fc="coral"))
    ax.text(0.8, 0.35, "Min Q-value", ha="center", va="center",
            bbox=dict(boxstyle="round,pad=0.3", fc="gold"))
    # Wiring: state feeds all three networks; critics merge into the min node.
    arrow_style = dict(arrowstyle="->", lw=1.5)
    connections = [
        ((0.38, 0.8), (0.15, 0.55)),   # state -> actor
        ((0.38, 0.5), (0.15, 0.5)),    # state -> critic 1
        ((0.38, 0.2), (0.15, 0.45)),   # state -> critic 2
        ((0.73, 0.8), (0.62, 0.8)),    # actor -> action
        ((0.68, 0.35), (0.62, 0.5)),   # critic 1 -> min
        ((0.68, 0.35), (0.62, 0.2)),   # critic 2 -> min
    ]
    for head, tail in connections:
        ax.annotate("", xy=head, xytext=tail, arrowprops=arrow_style)
def plot_epsilon_decay(ax, n_episodes=1000, decay_rate=0.005, min_epsilon=0.01):
    """Exploration: ε-Greedy Strategy Decay Curve.

    Plots ``epsilon = max(min_epsilon, exp(-decay_rate * episode))`` so the
    exploration probability decays exponentially but never drops below the
    floor. Defaults reproduce the original fixed curve exactly.

    Args:
        ax: Matplotlib axes to draw on.
        n_episodes: Number of episodes shown on the x-axis (default 1000).
        decay_rate: Exponential decay constant per episode (default 0.005).
        min_epsilon: Floor below which epsilon is clipped (default 0.01).
    """
    episodes = np.arange(0, n_episodes)
    # Exponential decay, clipped from below at the exploration floor.
    epsilon = np.maximum(min_epsilon, np.exp(-decay_rate * episodes))
    ax.plot(episodes, epsilon, color='purple', lw=2)
    ax.set_title(r"$\epsilon$-Greedy Decay Curve", fontsize=12, fontweight='bold')
    ax.set_xlabel("Episodes")
    ax.set_ylabel(r"Probability $\epsilon$")
    ax.grid(True, linestyle='--', alpha=0.6)
    # Light fill under the curve for visual emphasis.
    ax.fill_between(episodes, epsilon, color='purple', alpha=0.1)
def plot_learning_curve(ax):
    """Advanced / Misc: Learning Curve with Confidence Bands."""
    steps = np.linspace(0, 1e6, 100)
    # Saturating-exponential return toward 100 with additive noise.
    # NOTE(review): the noise is unseeded, so the curve differs run to run
    # -- seed here if reproducibility matters.
    noise = np.random.normal(0, 2, len(steps))
    mean_return = 100 * (1 - np.exp(-5e-6 * steps)) + noise
    # Spread shrinks as the simulated policy stabilizes.
    std_dev = 15 * np.exp(-2e-6 * steps)
    ax.plot(steps, mean_return, color='blue', lw=2, label="PPO (Mean)")
    ax.fill_between(steps, mean_return - std_dev, mean_return + std_dev,
                    color='blue', alpha=0.2, label="±1 Std Dev")
    ax.set_title("Learning Curve (Return vs Steps)", fontsize=12, fontweight='bold')
    ax.set_xlabel("Environment Steps")
    ax.set_ylabel("Average Episodic Return")
    ax.legend(loc="lower right")
    ax.grid(True, linestyle='--', alpha=0.6)
def main():
    """Build all eight RL overview figures and display them interactively."""
    # Figure 1: MDP & Environment. Wired by hand because the reward
    # landscape panel creates its own (3-D) axes from (fig, gs).
    fig1, gs1 = setup_figure("RL: MDP & Environment", 2, 4)
    plot_agent_env_loop(fig1.add_subplot(gs1[0, 0]))
    plot_mdp_graph(fig1.add_subplot(gs1[0, 1]))
    plot_trajectory(fig1.add_subplot(gs1[0, 2]))
    plot_continuous_space(fig1.add_subplot(gs1[0, 3]))
    plot_reward_landscape(fig1, gs1)  # projection='3d' handled inside
    plot_discount_decay(fig1.add_subplot(gs1[1, 1]))
    # (The State Transition Graph component duplicates plot_mdp_graph,
    # so it gets no extra panel.)

    # Figures 2-8 are uniform 2x4 grids filled row-major, so they are
    # driven from a (title, panel-functions) table.
    grids = [
        ("RL: Value, Policy & Dynamic Programming", [
            plot_value_heatmap, plot_action_value_q, plot_policy_arrows,
            plot_advantage_function, plot_backup_diagram,
            plot_policy_improvement, plot_value_iteration_backup,
            plot_policy_iteration_cycle,
        ]),
        ("RL: Monte Carlo & Temporal Difference", [
            plot_mc_backup, plot_mcts, plot_importance_sampling,
            plot_td_backup, plot_nstep_td, plot_eligibility_traces,
            plot_sarsa_backup, plot_q_learning_backup,
        ]),
        ("RL: TD Extensions & Function Approximation", [
            plot_double_q, plot_dueling_dqn, plot_prioritized_replay,
            plot_rainbow_dqn, plot_linear_fa, plot_nn_layers,
            plot_computation_graph, plot_target_network,
        ]),
        ("RL: Policy Gradients, Actor-Critic & Exploration", [
            plot_policy_gradient_flow, plot_ppo_clip,
            plot_trpo_trust_region, plot_actor_critic_arch,
            plot_a3c_multi_worker, plot_sac_arch,
            plot_softmax_exploration, plot_ucb_confidence,
        ]),
        ("RL: Hierarchical, Model-Based & Offline", [
            plot_options_framework, plot_feudal_networks, plot_world_model,
            plot_model_planning, plot_offline_rl, plot_cql_regularization,
            plot_epsilon_decay, plot_intrinsic_motivation,
        ]),
        ("RL: Multi-Agent, IRL & Meta-RL", [
            plot_multi_agent_interaction, plot_ctde, plot_payoff_matrix,
            plot_irl_reward_inference, plot_gail_flow,
            plot_meta_rl_nested_loop, plot_task_distribution,
        ]),
        ("RL: Advanced & Miscellaneous", [
            plot_replay_buffer, plot_state_visitation, plot_regret_curve,
            plot_attention_weights, plot_diffusion_policy, plot_gnn_rl,
            plot_latent_space, plot_convergence_log,
        ]),
    ]
    for title, panels in grids:
        fig, gs = setup_figure(title, 2, 4)  # constrained_layout inside
        for idx, panel in enumerate(panels):
            panel(fig.add_subplot(gs[idx // 4, idx % 4]))

    plt.show()
def save_all_graphs(output_dir="graphs"):
    """Save every mapped RL component diagram as a separate PNG file.

    Args:
        output_dir: Directory the PNGs are written to; created if missing.

    Each ``mapping`` entry pairs a human-readable component name with the
    plot function that draws it. Some functions serve several components,
    so those diagrams are saved under multiple filenames.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Component-to-Function Mapping (one entry per named RL component).
    mapping = {
        "Agent-Environment Interaction Loop": plot_agent_env_loop,
        "Markov Decision Process (MDP) Tuple": plot_mdp_graph,
        "State Transition Graph": plot_mdp_graph,
        "Trajectory / Episode Sequence": plot_trajectory,
        "Continuous State/Action Space Visualization": plot_continuous_space,
        "Reward Function / Landscape": plot_reward_landscape,
        "Discount Factor (gamma) Effect": plot_discount_decay,
        "State-Value Function V(s)": plot_value_heatmap,
        "Action-Value Function Q(s,a)": plot_action_value_q,
        "Policy pi(s) or pi(a|s)": plot_policy_arrows,
        "Advantage Function A(s,a)": plot_advantage_function,
        "Optimal Value Function V* / Q*": plot_value_heatmap,
        "Policy Evaluation Backup": plot_backup_diagram,
        "Policy Improvement": plot_policy_improvement,
        "Value Iteration Backup": plot_value_iteration_backup,
        "Policy Iteration Full Cycle": plot_policy_iteration_cycle,
        "Monte Carlo Backup": plot_mc_backup,
        "Monte Carlo Tree (MCTS)": plot_mcts,
        "Importance Sampling Ratio": plot_importance_sampling,
        "TD(0) Backup": plot_td_backup,
        "Bootstrapping (general)": plot_td_backup,
        "n-step TD Backup": plot_nstep_td,
        "TD(lambda) & Eligibility Traces": plot_eligibility_traces,
        "SARSA Update": plot_sarsa_backup,
        "Q-Learning Update": plot_q_learning_backup,
        "Expected SARSA": plot_expected_sarsa_backup,
        "Double Q-Learning / Double DQN": plot_double_q,
        "Dueling DQN Architecture": plot_dueling_dqn,
        "Prioritized Experience Replay": plot_prioritized_replay,
        "Rainbow DQN Components": plot_rainbow_dqn,
        "Linear Function Approximation": plot_linear_fa,
        "Neural Network Layers (MLP, CNN, RNN, Transformer)": plot_nn_layers,
        "Computation Graph / Backpropagation Flow": plot_computation_graph,
        "Target Network": plot_target_network,
        "Policy Gradient Theorem": plot_policy_gradient_flow,
        "REINFORCE Update": plot_reinforce_flow,
        "Baseline / Advantage Subtraction": plot_advantage_scaled_grad,
        "Trust Region (TRPO)": plot_trpo_trust_region,
        "Proximal Policy Optimization (PPO)": plot_ppo_clip,
        "Actor-Critic Architecture": plot_actor_critic_arch,
        "Advantage Actor-Critic (A2C/A3C)": plot_a3c_multi_worker,
        "Soft Actor-Critic (SAC)": plot_sac_arch,
        "Twin Delayed DDPG (TD3)": plot_actor_critic_arch,
        "epsilon-Greedy Strategy": plot_epsilon_decay,
        "Softmax / Boltzmann Exploration": plot_softmax_exploration,
        "Upper Confidence Bound (UCB)": plot_ucb_confidence,
        "Intrinsic Motivation / Curiosity": plot_intrinsic_motivation,
        "Entropy Regularization": plot_entropy_bonus,
        "Options Framework": plot_options_framework,
        "Feudal Networks / Hierarchical Actor-Critic": plot_feudal_networks,
        "Skill Discovery": plot_skill_discovery,
        "Learned Dynamics Model": plot_world_model,
        "Model-Based Planning": plot_model_planning,
        "Imagination-Augmented Agents (I2A)": plot_imagination_rollout,
        "Offline Dataset": plot_offline_rl,
        "Conservative Q-Learning (CQL)": plot_cql_regularization,
        "Multi-Agent Interaction Graph": plot_multi_agent_interaction,
        "Centralized Training Decentralized Execution (CTDE)": plot_ctde,
        "Cooperative / Competitive Payoff Matrix": plot_payoff_matrix,
        "Reward Inference": plot_irl_reward_inference,
        "Generative Adversarial Imitation Learning (GAIL)": plot_gail_flow,
        "Meta-RL Architecture": plot_meta_rl_nested_loop,
        "Task Distribution Visualization": plot_task_distribution,
        "Experience Replay Buffer": plot_replay_buffer,
        "State Visitation / Occupancy Measure": plot_state_visitation,
        "Learning Curve": plot_learning_curve,
        "Regret / Cumulative Regret": plot_regret_curve,
        "Attention Mechanisms (Transformers in RL)": plot_attention_weights,
        "Diffusion Policy": plot_diffusion_policy,
        "Graph Neural Networks for RL": plot_gnn_rl,
        "World Model / Latent Space": plot_latent_space,
        "Convergence Analysis Plots": plot_convergence_log
    }
    for name, func in mapping.items():
        # Sanitize the display name into a safe snake_case filename.
        filename = re.sub(r'[^a-zA-Z0-9]', '_', name.lower()).strip('_')
        filename = re.sub(r'_+', '_', filename) + ".png"
        filepath = os.path.join(output_dir, filename)
        # Bug fix: the original f-string had no placeholder and printed a
        # literal; report the actual output path being generated.
        print(f"Generating: {filepath} ...")
        plt.close('all')
        if func is plot_reward_landscape:
            # This plot builds its own (3-D) axes from (fig, gs).
            fig = plt.figure(figsize=(10, 8))
            gs = GridSpec(1, 1, figure=fig)
            func(fig, gs)
        else:
            fig, ax = plt.subplots(figsize=(10, 8), constrained_layout=True)
            func(ax)
        plt.savefig(filepath, bbox_inches='tight', dpi=100)
        plt.close(fig)
    print(f"\n[SUCCESS] Saved {len(mapping)} graphs to '{output_dir}/' directory.")
if __name__ == "__main__":
    import sys

    # `--save` exports every diagram to PNG; the default shows the
    # figures interactively.
    entry_point = save_all_graphs if "--save" in sys.argv else main
    entry_point()