algorembrant committed
Commit 8744e5e · verified · 1 Parent(s): d0d4f5d

Upload 76 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. README.md +80 -3
  3. core.py +977 -0
  4. graphs/action_value_function_q_s_a.png +0 -0
  5. graphs/actor_critic_architecture.png +0 -0
  6. graphs/advantage_actor_critic_a2c_a3c.png +0 -0
  7. graphs/advantage_function_a_s_a.png +0 -0
  8. graphs/agent_environment_interaction_loop.png +0 -0
  9. graphs/attention_mechanisms_transformers_in_rl.png +0 -0
  10. graphs/baseline_advantage_subtraction.png +0 -0
  11. graphs/bootstrapping_general.png +0 -0
  12. graphs/centralized_training_decentralized_execution_ctde.png +0 -0
  13. graphs/computation_graph_backpropagation_flow.png +0 -0
  14. graphs/conservative_q_learning_cql.png +0 -0
  15. graphs/continuous_state_action_space_visualization.png +0 -0
  16. graphs/convergence_analysis_plots.png +0 -0
  17. graphs/cooperative_competitive_payoff_matrix.png +0 -0
  18. graphs/diffusion_policy.png +0 -0
  19. graphs/discount_factor_gamma_effect.png +0 -0
  20. graphs/double_q_learning_double_dqn.png +0 -0
  21. graphs/dueling_dqn_architecture.png +0 -0
  22. graphs/entropy_regularization.png +0 -0
  23. graphs/epsilon_greedy_strategy.png +0 -0
  24. graphs/expected_sarsa.png +0 -0
  25. graphs/experience_replay_buffer.png +0 -0
  26. graphs/feudal_networks_hierarchical_actor_critic.png +0 -0
  27. graphs/generative_adversarial_imitation_learning_gail.png +0 -0
  28. graphs/graph_neural_networks_for_rl.png +0 -0
  29. graphs/imagination_augmented_agents_i2a.png +0 -0
  30. graphs/importance_sampling_ratio.png +0 -0
  31. graphs/intrinsic_motivation_curiosity.png +0 -0
  32. graphs/learned_dynamics_model.png +0 -0
  33. graphs/learning_curve.png +0 -0
  34. graphs/linear_function_approximation.png +0 -0
  35. graphs/markov_decision_process_mdp_tuple.png +0 -0
  36. graphs/meta_rl_architecture.png +0 -0
  37. graphs/model_based_planning.png +0 -0
  38. graphs/monte_carlo_backup.png +0 -0
  39. graphs/monte_carlo_tree_mcts.png +0 -0
  40. graphs/multi_agent_interaction_graph.png +0 -0
  41. graphs/n_step_td_backup.png +0 -0
  42. graphs/neural_network_layers_mlp_cnn_rnn_transformer.png +0 -0
  43. graphs/offline_dataset.png +0 -0
  44. graphs/optimal_value_function_v_q.png +0 -0
  45. graphs/options_framework.png +0 -0
  46. graphs/policy_evaluation_backup.png +0 -0
  47. graphs/policy_gradient_theorem.png +0 -0
  48. graphs/policy_improvement.png +0 -0
  49. graphs/policy_iteration_full_cycle.png +0 -0
  50. graphs/policy_pi_s_or_pi_a_s.png +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ graphs/reward_function_landscape.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,80 @@
- ---
- license: mit
- ---
+ ---
+ title: Reinforcement Learning Graphical Representations
+ date: 2026-04-08
+ category: Reinforcement Learning
+ description: A comprehensive gallery of 72 standard RL components and their graphical representations.
+ ---
+
+ # Reinforcement Learning Graphical Representations
+
+ This repository contains a full set of 72 visualizations representing foundational concepts, algorithms, and advanced topics in Reinforcement Learning.
+
+ | Category | Component | Illustration | Details | Context |
+ |----------|-----------|--------------|---------|---------|
+ | **MDP & Environment** | **Agent-Environment Interaction Loop** | ![Illustration](graphs/agent_environment_interaction_loop.png) | Core cycle: observation of state → selection of action → environment transition → receipt of reward + next state | All RL algorithms |
+ | **MDP & Environment** | **Markov Decision Process (MDP) Tuple** | ![Illustration](graphs/markov_decision_process_mdp_tuple.png) | (S, A, P, R, γ) with transition dynamics P(s′\|s,a) and reward function R(s,a,s′) | All RL algorithms |
+ | **MDP & Environment** | **State Transition Graph** | ![Illustration](graphs/state_transition_graph.png) | Full probabilistic transitions between discrete states | Gridworld, Taxi, Cliff Walking |
+ | **MDP & Environment** | **Trajectory / Episode Sequence** | ![Illustration](graphs/trajectory_episode_sequence.png) | Sequence of (s₀, a₀, r₁, s₁, …, s_T) | Monte Carlo, episodic tasks |
+ | **MDP & Environment** | **Continuous State/Action Space Visualization** | ![Illustration](graphs/continuous_state_action_space_visualization.png) | High-dimensional spaces (e.g., robot joints, pixel inputs) | Continuous-control tasks (MuJoCo, PyBullet) |
+ | **MDP & Environment** | **Reward Function / Landscape** | ![Illustration](graphs/reward_function_landscape.png) | Scalar reward as function of state/action | All algorithms; especially reward shaping |
+ | **MDP & Environment** | **Discount Factor (γ) Effect** | ![Illustration](graphs/discount_factor_gamma_effect.png) | How future rewards are weighted | All discounted MDPs |
+ | **Value & Policy** | **State-Value Function V(s)** | ![Illustration](graphs/state_value_function_v_s.png) | Expected return from state s under policy π | Value-based methods |
+ | **Value & Policy** | **Action-Value Function Q(s,a)** | ![Illustration](graphs/action_value_function_q_s_a.png) | Expected return from state-action pair | Q-learning family |
+ | **Value & Policy** | **Policy π(s) or π(a\|s)** | ![Illustration](graphs/policy_pi_s_or_pi_a_s.png) | Arrow overlays on grid (optimal policy), probability bar charts, or softmax heatmaps | Policy-based methods |
+ | **Value & Policy** | **Advantage Function A(s,a)** | ![Illustration](graphs/advantage_function_a_s_a.png) | Q(s,a) – V(s) | A2C, PPO, SAC, TD3 |
+ | **Value & Policy** | **Optimal Value Function V\* / Q\*** | ![Illustration](graphs/optimal_value_function_v_q.png) | Solution to Bellman optimality | Value iteration, Q-learning |
+ | **Dynamic Programming** | **Policy Evaluation Backup** | ![Illustration](graphs/policy_evaluation_backup.png) | Iterative update of V using Bellman expectation | Policy iteration |
+ | **Dynamic Programming** | **Policy Improvement** | ![Illustration](graphs/policy_improvement.png) | Greedy policy update over Q | Policy iteration |
+ | **Dynamic Programming** | **Value Iteration Backup** | ![Illustration](graphs/value_iteration_backup.png) | Update using Bellman optimality | Value iteration |
+ | **Dynamic Programming** | **Policy Iteration Full Cycle** | ![Illustration](graphs/policy_iteration_full_cycle.png) | Evaluation → Improvement loop | Classic DP methods |
+ | **Monte Carlo** | **Monte Carlo Backup** | ![Illustration](graphs/monte_carlo_backup.png) | Update using full episode return G_t | First-visit / every-visit MC |
+ | **Monte Carlo** | **Monte Carlo Tree (MCTS)** | ![Illustration](graphs/monte_carlo_tree_mcts.png) | Search tree with selection, expansion, simulation, backprop | AlphaGo, AlphaZero |
+ | **Monte Carlo** | **Importance Sampling Ratio** | ![Illustration](graphs/importance_sampling_ratio.png) | Off-policy correction ρ = π(a\|s) / b(a\|s) | Off-policy Monte Carlo |
+ | **Temporal Difference** | **TD(0) Backup** | ![Illustration](graphs/td_0_backup.png) | Bootstrapped update using R + γV(s′) | TD learning |
+ | **Temporal Difference** | **Bootstrapping (general)** | ![Illustration](graphs/bootstrapping_general.png) | Using estimated future value instead of full return | All TD methods |
+ | **Temporal Difference** | **n-step TD Backup** | ![Illustration](graphs/n_step_td_backup.png) | Multi-step return G_t^{(n)} | n-step TD, TD(λ) |
+ | **Temporal Difference** | **TD(λ) & Eligibility Traces** | ![Illustration](graphs/td_eligibility_traces.png) | Decaying trace z_t for credit assignment | TD(λ), SARSA(λ), Q(λ) |
+ | **Temporal Difference** | **SARSA Update** | ![Illustration](graphs/sarsa_update.png) | On-policy TD control | SARSA |
+ | **Temporal Difference** | **Q-Learning Update** | ![Illustration](graphs/q_learning_update.png) | Off-policy TD control | Q-learning, Deep Q-Network |
+ | **Temporal Difference** | **Expected SARSA** | ![Illustration](graphs/expected_sarsa.png) | Expectation over next action under policy | Expected SARSA |
+ | **Temporal Difference** | **Double Q-Learning / Double DQN** | ![Illustration](graphs/double_q_learning_double_dqn.png) | Two separate Q estimators to reduce overestimation | Double DQN, TD3 |
+ | **Temporal Difference** | **Dueling DQN Architecture** | ![Illustration](graphs/dueling_dqn_architecture.png) | Separate streams for state value V(s) and advantage A(s,a) | Dueling DQN |
+ | **Temporal Difference** | **Prioritized Experience Replay** | ![Illustration](graphs/prioritized_experience_replay.png) | Importance sampling of transitions by TD error | Prioritized DQN, Rainbow |
+ | **Temporal Difference** | **Rainbow DQN Components** | ![Illustration](graphs/rainbow_dqn_components.png) | All extensions combined (Double, Dueling, PER, etc.) | Rainbow DQN |
+ | **Function Approximation** | **Linear Function Approximation** | ![Illustration](graphs/linear_function_approximation.png) | Feature vector φ(s) → wᵀφ(s) | Tabular → linear FA |
+ | **Function Approximation** | **Neural Network Layers (MLP, CNN, RNN, Transformer)** | ![Illustration](graphs/neural_network_layers_mlp_cnn_rnn_transformer.png) | Full deep network for value/policy | DQN, A3C, PPO, Decision Transformer |
+ | **Function Approximation** | **Computation Graph / Backpropagation Flow** | ![Illustration](graphs/computation_graph_backpropagation_flow.png) | Gradient flow through network | All deep RL |
+ | **Function Approximation** | **Target Network** | ![Illustration](graphs/target_network.png) | Frozen copy of Q-network for stability | DQN, DDQN, SAC, TD3 |
+ | **Policy Gradients** | **Policy Gradient Theorem** | ![Illustration](graphs/policy_gradient_theorem.png) | ∇_θ J(θ) = E[∇_θ log π(a\|s) Q^π(s,a)] | Flow diagram from reward → log-prob → gradient |
+ | **Policy Gradients** | **REINFORCE Update** | ![Illustration](graphs/reinforce_update.png) | Monte-Carlo policy gradient | REINFORCE |
+ | **Policy Gradients** | **Baseline / Advantage Subtraction** | ![Illustration](graphs/baseline_advantage_subtraction.png) | Subtract b(s) to reduce variance | All modern PG |
+ | **Policy Gradients** | **Trust Region (TRPO)** | ![Illustration](graphs/trust_region_trpo.png) | KL-divergence constraint on policy update | TRPO |
+ | **Policy Gradients** | **Proximal Policy Optimization (PPO)** | ![Illustration](graphs/proximal_policy_optimization_ppo.png) | Clipped surrogate objective | PPO, PPO-Clip |
+ | **Actor-Critic** | **Actor-Critic Architecture** | ![Illustration](graphs/actor_critic_architecture.png) | Separate or shared actor (policy) + critic (value) networks | A2C, A3C, SAC, TD3 |
+ | **Actor-Critic** | **Advantage Actor-Critic (A2C/A3C)** | ![Illustration](graphs/advantage_actor_critic_a2c_a3c.png) | Synchronous/asynchronous multi-worker | A2C/A3C |
+ | **Actor-Critic** | **Soft Actor-Critic (SAC)** | ![Illustration](graphs/soft_actor_critic_sac.png) | Entropy-regularized policy + twin critics | SAC |
+ | **Actor-Critic** | **Twin Delayed DDPG (TD3)** | ![Illustration](graphs/twin_delayed_ddpg_td3.png) | Twin critics + delayed policy + target smoothing | TD3 |
+ | **Exploration** | **ε-Greedy Strategy** | ![Illustration](graphs/epsilon_greedy_strategy.png) | Probability ε of random action | DQN family |
+ | **Exploration** | **Softmax / Boltzmann Exploration** | ![Illustration](graphs/softmax_boltzmann_exploration.png) | Temperature τ in softmax | Softmax policies |
+ | **Exploration** | **Upper Confidence Bound (UCB)** | ![Illustration](graphs/upper_confidence_bound_ucb.png) | Optimism in face of uncertainty | UCB1, bandits |
+ | **Exploration** | **Intrinsic Motivation / Curiosity** | ![Illustration](graphs/intrinsic_motivation_curiosity.png) | Prediction error as intrinsic reward | ICM, RND, Curiosity-driven RL |
+ | **Exploration** | **Entropy Regularization** | ![Illustration](graphs/entropy_regularization.png) | Bonus term αH(π) | SAC, maximum-entropy RL |
+ | **Hierarchical RL** | **Options Framework** | ![Illustration](graphs/options_framework.png) | High-level policy over options (temporally extended actions) | Option-Critic |
+ | **Hierarchical RL** | **Feudal Networks / Hierarchical Actor-Critic** | ![Illustration](graphs/feudal_networks_hierarchical_actor_critic.png) | Manager-worker hierarchy | Feudal RL |
+ | **Hierarchical RL** | **Skill Discovery** | ![Illustration](graphs/skill_discovery.png) | Unsupervised emergence of reusable skills | DIAYN, VALOR |
+ | **Model-Based RL** | **Learned Dynamics Model** | ![Illustration](graphs/learned_dynamics_model.png) | Learned transition model P̂(s′\|s,a) | Separate model network diagram (often RNN or transformer) |
+ | **Model-Based RL** | **Model-Based Planning** | ![Illustration](graphs/model_based_planning.png) | Rollouts inside learned model | MuZero, DreamerV3 |
+ | **Model-Based RL** | **Imagination-Augmented Agents (I2A)** | ![Illustration](graphs/imagination_augmented_agents_i2a.png) | Imagination module + policy | I2A |
+ | **Offline RL** | **Offline Dataset** | ![Illustration](graphs/offline_dataset.png) | Fixed batch of trajectories | BC, CQL, IQL |
+ | **Offline RL** | **Conservative Q-Learning (CQL)** | ![Illustration](graphs/conservative_q_learning_cql.png) | Penalty on out-of-distribution actions | CQL |
+ | **Multi-Agent RL** | **Multi-Agent Interaction Graph** | ![Illustration](graphs/multi_agent_interaction_graph.png) | Agents communicating or competing | MARL, MADDPG |
+ | **Multi-Agent RL** | **Centralized Training Decentralized Execution (CTDE)** | ![Illustration](graphs/centralized_training_decentralized_execution_ctde.png) | Shared critic during training | QMIX, VDN, MADDPG |
+ | **Multi-Agent RL** | **Cooperative / Competitive Payoff Matrix** | ![Illustration](graphs/cooperative_competitive_payoff_matrix.png) | Joint reward for multiple agents | Prisoner's Dilemma, multi-agent gridworlds |
+ | **Inverse RL / IRL** | **Reward Inference** | ![Illustration](graphs/reward_inference.png) | Infer reward from expert demonstrations | IRL, GAIL |
+ | **Inverse RL / IRL** | **Generative Adversarial Imitation Learning (GAIL)** | ![Illustration](graphs/generative_adversarial_imitation_learning_gail.png) | Discriminator vs. policy generator | GAIL, AIRL |
+ | **Meta-RL** | **Meta-RL Architecture** | ![Illustration](graphs/meta_rl_architecture.png) | Outer loop (meta-policy) + inner loop (task adaptation) | MAML for RL, RL² |
+ | **Meta-RL** | **Task Distribution Visualization** | ![Illustration](graphs/task_distribution_visualization.png) | Multiple MDPs sampled from meta-distribution | Meta-RL benchmarks |
+ | **Advanced / Misc** | **Experience Replay Buffer** | ![Illustration](graphs/experience_replay_buffer.png) | Stored (s,a,r,s′,done) tuples | DQN and all off-policy deep RL |
+ | **Advanced / Misc** | **State Visitation / Occupancy Measure** | ![Illustration](graphs/state_visitation_occupancy_measure.png) | Frequency of visiting each state | All algorithms (analysis) |
+ | **Advanced / Misc** | **Learning Curve** | ![Illustration](graphs/learning_curve.png) | Average episodic return vs. episodes / steps | Standard performance reporting |
+ | **Advanced / Misc** | **Regret / Cumulative Regret** | ![Illustration](graphs/regret_cumulative_regret.png) | Sub-optimality accumulated over time | Bandits and online RL |
+ | **Advanced / Misc** | **Attention Mechanisms (Transformers in RL)** | ![Illustration](graphs/attention_mechanisms_transformers_in_rl.png) | Attention weights | Decision Transformer, Trajectory Transformer |
+ | **Advanced / Misc** | **Diffusion Policy** | ![Illustration](graphs/diffusion_policy.png) | Denoising diffusion process for action generation | Diffusion-RL policies |
+ | **Advanced / Misc** | **Graph Neural Networks for RL** | ![Illustration](graphs/graph_neural_networks_for_rl.png) | Node/edge message passing | Graph RL, relational RL |
+ | **Advanced / Misc** | **World Model / Latent Space** | ![Illustration](graphs/world_model_latent_space.png) | Encoder-decoder dynamics in latent space | Dreamer, PlaNet |
+ | **Advanced / Misc** | **Convergence Analysis Plots** | ![Illustration](graphs/convergence_analysis_plots.png) | Error / value change over iterations | DP, TD, value iteration |
core.py ADDED
@@ -0,0 +1,977 @@
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import networkx as nx
+ from matplotlib.gridspec import GridSpec
+ from matplotlib.patches import FancyArrowPatch
+
+ import os
+ import re
+
+ def setup_figure(title, rows, cols):
+     """Initializes a new figure and grid layout with constrained_layout to avoid warnings."""
+     fig = plt.figure(figsize=(20, 10), constrained_layout=True)
+     fig.suptitle(title, fontsize=18, fontweight='bold')
+     gs = GridSpec(rows, cols, figure=fig)
+     return fig, gs
+
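A quick usage sketch for `setup_figure` (the demo title and data are hypothetical; the helper is repeated so the snippet runs standalone, and the non-interactive `Agg` backend is an assumption for headless use):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

def setup_figure(title, rows, cols):
    """Same shape as the core.py helper: titled figure plus a GridSpec layout."""
    fig = plt.figure(figsize=(20, 10), constrained_layout=True)
    fig.suptitle(title, fontsize=18, fontweight='bold')
    gs = GridSpec(rows, cols, figure=fig)
    return fig, gs

fig, gs = setup_figure("Demo Dashboard", 1, 2)  # 1 row, 2 columns
ax = fig.add_subplot(gs[0, 0])                  # claim the left slot
ax.plot([0, 1, 2], [0, 1, 4])                   # each plot_* function draws into a slot like this
```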
+ def plot_agent_env_loop(ax):
+     """MDP & Environment: Agent-Environment Interaction Loop (Flowchart)."""
+     ax.axis('off')
+     ax.set_title("Agent-Environment Interaction", fontsize=12, fontweight='bold')
+
+     props = dict(boxstyle="round,pad=0.8", fc="ivory", ec="black", lw=1.5)
+     ax.text(0.5, 0.8, "Agent", ha="center", va="center", bbox=props, fontsize=12)
+     ax.text(0.5, 0.2, "Environment", ha="center", va="center", bbox=props, fontsize=12)
+
+     # Agent to Env: Action
+     ax.annotate("Action $A_t$", xy=(0.5, 0.35), xytext=(0.5, 0.65),
+                 arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5", lw=2))
+     # Env to Agent: State & Reward
+     ax.annotate("State $S_{t+1}$, Reward $R_{t+1}$", xy=(0.5, 0.65), xytext=(0.5, 0.35),
+                 arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5", lw=2, color='green'))
+
+ def plot_mdp_graph(ax):
+     """MDP & Environment: Directed graph with probability-weighted arrows."""
+     G = nx.DiGraph()
+     # Edge attributes are passed as a dictionary
+     G.add_edges_from([
+         ('S0', 'S1', {'weight': 0.8}), ('S0', 'S2', {'weight': 0.2}),
+         ('S1', 'S2', {'weight': 1.0}), ('S2', 'S0', {'weight': 0.5}), ('S2', 'S2', {'weight': 0.5})
+     ])
+     pos = nx.spring_layout(G, seed=42)
+     nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, node_size=1500, node_color='lightblue')
+     nx.draw_networkx_labels(ax=ax, G=G, pos=pos, font_weight='bold')
+
+     edge_labels = {(u, v): f"P={d['weight']}" for u, v, d in G.edges(data=True)}
+     nx.draw_networkx_edges(ax=ax, G=G, pos=pos, arrowsize=20, edge_color='gray', connectionstyle="arc3,rad=0.1")
+     nx.draw_networkx_edge_labels(ax=ax, G=G, pos=pos, edge_labels=edge_labels, font_size=9)
+     ax.set_title("MDP State Transition Graph", fontsize=12, fontweight='bold')
+     ax.axis('off')
+
+ def plot_reward_landscape(fig, gs):
+     """MDP & Environment: 3D surface plot of a reward function."""
+     # Use the first available slot in gs (handled flexibly for dashboard vs save)
+     try:
+         ax = fig.add_subplot(gs[0, 1], projection='3d')
+     except IndexError:
+         ax = fig.add_subplot(gs[0, 0], projection='3d')
+     X = np.linspace(-5, 5, 50)
+     Y = np.linspace(-5, 5, 50)
+     X, Y = np.meshgrid(X, Y)
+     Z = np.sin(np.sqrt(X**2 + Y**2)) + (X * 0.1)  # Simulated reward landscape
+
+     ax.plot_surface(X, Y, Z, cmap='viridis', edgecolor='none', alpha=0.9)
+     ax.set_title("Reward Function Landscape", fontsize=12, fontweight='bold')
+     ax.set_xlabel('State X')
+     ax.set_ylabel('State Y')
+     ax.set_zlabel('Reward R(s)')
+
+ def plot_trajectory(ax):
+     """MDP & Environment: Trajectory / Episode Sequence."""
+     ax.set_title("Trajectory Sequence", fontsize=12, fontweight='bold')
+     states = ['s0', 's1', 's2', 's3', 'sT']
+     actions = ['a0', 'a1', 'a2', 'a3']
+     rewards = ['r1', 'r2', 'r3', 'r4']
+
+     for i, s in enumerate(states):
+         ax.text(i, 0.5, s, ha='center', va='center', bbox=dict(boxstyle="circle", fc="white"))
+         if i < len(actions):
+             ax.annotate("", xy=(i+0.8, 0.5), xytext=(i+0.2, 0.5), arrowprops=dict(arrowstyle="->"))
+             ax.text(i+0.5, 0.6, actions[i], ha='center', color='blue')
+             ax.text(i+0.5, 0.4, rewards[i], ha='center', color='red')
+
+     ax.set_xlim(-0.5, len(states)-0.5)
+     ax.set_ylim(0, 1)
+     ax.axis('off')
+
+ def plot_continuous_space(ax):
+     """MDP & Environment: Continuous State/Action Space Visualization."""
+     np.random.seed(42)
+     x = np.random.randn(200, 2)
+     labels = np.linalg.norm(x, axis=1) > 1.0
+     ax.scatter(x[labels, 0], x[labels, 1], c='coral', alpha=0.6, label='High Reward')
+     ax.scatter(x[~labels, 0], x[~labels, 1], c='skyblue', alpha=0.6, label='Low Reward')
+     ax.set_title("Continuous State Space (2D Projection)", fontsize=12, fontweight='bold')
+     ax.legend(fontsize=8)
+
+ def plot_discount_decay(ax):
+     """MDP & Environment: Discount Factor (gamma) Effect."""
+     t = np.arange(0, 20)
+     for gamma in [0.5, 0.9, 0.99]:
+         ax.plot(t, gamma**t, marker='o', markersize=4, label=rf"$\gamma={gamma}$")
+     ax.set_title(r"Discount Factor $\gamma^t$ Decay", fontsize=12, fontweight='bold')
+     ax.set_xlabel("Time steps (t)")
+     ax.set_ylabel("Weight")
+     ax.legend()
+     ax.grid(True, alpha=0.3)
+
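The γ^t curves drawn by `plot_discount_decay` are exactly the weights in the discounted return G = Σ_k γ^k r_k. A tiny numeric sketch (the reward sequence is made up for illustration):

```python
import numpy as np

def discounted_return(rewards, gamma):
    """G = sum_k gamma**k * r_k — the quantity the decay curves weight."""
    weights = gamma ** np.arange(len(rewards))
    return float(np.dot(weights, rewards))

rewards = [1.0, 1.0, 1.0, 1.0]
print(discounted_return(rewards, 0.5))   # → 1.875  (1 + 0.5 + 0.25 + 0.125)
print(discounted_return(rewards, 1.0))   # → 4.0    (undiscounted sum)
```

A smaller γ makes the agent effectively myopic, which is what the steep γ=0.5 curve in the plot conveys.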
+ def plot_value_heatmap(ax):
+     """Value & Policy: State-Value Function V(s) Heatmap (Gridworld)."""
+     grid_size = 5
+     # Simulate a value landscape where the bottom-right cell is the goal
+     values = np.zeros((grid_size, grid_size))
+     for i in range(grid_size):
+         for j in range(grid_size):
+             values[i, j] = -((grid_size-1-i)**2 + (grid_size-1-j)**2) * 0.5
+     values[-1, -1] = 10.0  # Goal state
+
+     ax.matshow(values, cmap='magma')
+     for (i, j), z in np.ndenumerate(values):
+         ax.text(j, i, f'{z:0.1f}', ha='center', va='center', color='white' if z < -5 else 'black', fontsize=9)
+
+     ax.set_title("State-Value Function V(s) Heatmap", fontsize=12, fontweight='bold', pad=15)
+     ax.set_xticks(range(grid_size))
+     ax.set_yticks(range(grid_size))
+
+ def plot_backup_diagram(ax):
+     """Dynamic Programming: Policy Evaluation Backup Diagram."""
+     G = nx.DiGraph()
+     G.add_node("s", layer=0)
+     G.add_node("a1", layer=1); G.add_node("a2", layer=1)
+     G.add_node("s'_1", layer=2); G.add_node("s'_2", layer=2); G.add_node("s'_3", layer=2)
+
+     G.add_edges_from([("s", "a1"), ("s", "a2")])
+     G.add_edges_from([("a1", "s'_1"), ("a1", "s'_2"), ("a2", "s'_3")])
+
+     pos = {
+         "s": (0.5, 1),
+         "a1": (0.25, 0.5), "a2": (0.75, 0.5),
+         "s'_1": (0.1, 0), "s'_2": (0.4, 0), "s'_3": (0.75, 0)
+     }
+
+     nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, nodelist=["s", "s'_1", "s'_2", "s'_3"], node_size=800, node_color='white', edgecolors='black')
+     nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, nodelist=["a1", "a2"], node_size=300, node_color='black')  # Action nodes are solid black dots
+     nx.draw_networkx_edges(ax=ax, G=G, pos=pos, arrows=True)
+     nx.draw_networkx_labels(ax=ax, G=G, pos=pos, labels={"s": "s", "s'_1": "s'", "s'_2": "s'", "s'_3": "s'"}, font_size=10)
+
+     ax.set_title("DP Policy Eval Backup", fontsize=12, fontweight='bold')
+     ax.set_ylim(-0.2, 1.2)
+     ax.axis('off')
+
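The backup diagram drawn by `plot_backup_diagram` corresponds to the Bellman expectation update V ← R + γPV. A minimal numeric sketch on a hypothetical two-state chain (values here are invented for the example, not taken from any repository figure):

```python
import numpy as np

P = np.array([[0.0, 1.0],   # P[s, s'] under a fixed policy: state 0 -> state 1
              [0.0, 1.0]])  # state 1 loops on itself (absorbing)
R = np.array([0.0, 1.0])    # expected one-step reward in each state
gamma = 0.9

V = np.zeros(2)
for _ in range(200):        # repeated Bellman expectation backups
    V = R + gamma * (P @ V)

# Fixed point: V[1] = 1/(1-gamma) = 10, and V[0] = gamma * V[1] = 9
print(np.round(V, 4))
```

Each pass through the loop is one application of the backup the diagram depicts: expected reward plus the discounted value of the successor states.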
+ def plot_action_value_q(ax):
+     """Value & Policy: Action-Value Function Q(s,a) (Heatmap per action stack)."""
+     grid = np.random.rand(3, 3)
+     ax.imshow(grid, cmap='YlGnBu')
+     for (i, j), z in np.ndenumerate(grid):
+         ax.text(j, i, f'{z:0.1f}', ha='center', va='center', fontsize=8)
+     ax.set_title(r"Action-Value $Q(s, a_{up})$", fontsize=12, fontweight='bold')
+     ax.set_xticks([]); ax.set_yticks([])
+
+ def plot_policy_arrows(ax):
+     """Value & Policy: Policy π(s) as arrow overlays on grid."""
+     grid_size = 4
+     ax.set_xlim(-0.5, grid_size-0.5)
+     ax.set_ylim(-0.5, grid_size-0.5)
+     for i in range(grid_size):
+         for j in range(grid_size):
+             dx, dy = np.random.choice([0, 0.3, -0.3]), np.random.choice([0, 0.3, -0.3])
+             if dx == 0 and dy == 0:
+                 dx = 0.3
+             ax.add_patch(FancyArrowPatch((j, i), (j+dx, i+dy), arrowstyle='->', mutation_scale=15))
+     ax.set_title(r"Policy $\pi(s)$ Arrows", fontsize=12, fontweight='bold')
+     ax.set_xticks(range(grid_size)); ax.set_yticks(range(grid_size)); ax.grid(True, alpha=0.2)
+
+ def plot_advantage_function(ax):
+     """Value & Policy: Advantage Function A(s,a) = Q - V."""
+     actions = ['A1', 'A2', 'A3', 'A4']
+     advantage = [2.1, -1.2, 0.5, -0.8]
+     colors = ['green' if v > 0 else 'red' for v in advantage]
+     ax.bar(actions, advantage, color=colors, alpha=0.7)
+     ax.axhline(0, color='black', lw=1)
+     ax.set_title(r"Advantage $A(s, a)$", fontsize=12, fontweight='bold')
+     ax.set_ylabel("Value")
+
+ def plot_policy_improvement(ax):
+     """Dynamic Programming: Policy Improvement (Before vs After)."""
+     ax.axis('off')
+     ax.set_title("Policy Improvement", fontsize=12, fontweight='bold')
+     ax.text(0.2, 0.5, r"$\pi_{old}$", fontsize=15, bbox=dict(boxstyle="round", fc="lightgrey"))
+     ax.annotate("", xy=(0.8, 0.5), xytext=(0.3, 0.5), arrowprops=dict(arrowstyle="->", lw=2))
+     ax.text(0.5, 0.6, "Greedy\nImprovement", ha='center', fontsize=9)
+     ax.text(0.85, 0.5, r"$\pi_{new}$", fontsize=15, bbox=dict(boxstyle="round", fc="lightgreen"))
+
+ def plot_value_iteration_backup(ax):
+     """Dynamic Programming: Value Iteration Backup Diagram (Max over actions)."""
+     G = nx.DiGraph()
+     pos = {"s": (0.5, 1), "max": (0.5, 0.5), "s1": (0.2, 0), "s2": (0.5, 0), "s3": (0.8, 0)}
+     G.add_nodes_from(pos.keys())
+     G.add_edges_from([("s", "max"), ("max", "s1"), ("max", "s2"), ("max", "s3")])
+
+     nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, node_size=500, node_color='white', edgecolors='black')
+     nx.draw_networkx_edges(ax=ax, G=G, pos=pos, arrows=True)
+     nx.draw_networkx_labels(ax=ax, G=G, pos=pos, labels={"s": "s", "max": "max", "s1": "s'", "s2": "s'", "s3": "s'"}, font_size=9)
+     ax.set_title("Value Iteration Backup", fontsize=12, fontweight='bold')
+     ax.axis('off')
+
+ def plot_policy_iteration_cycle(ax):
+     """Dynamic Programming: Policy Iteration Full Cycle Flowchart."""
+     ax.axis('off')
+     ax.set_title("Policy Iteration Cycle", fontsize=12, fontweight='bold')
+     props = dict(boxstyle="round", fc="aliceblue", ec="black")
+     ax.text(0.5, 0.8, "Policy Evaluation" + "\n" + r"$V \leftarrow V^\pi$", ha="center", bbox=props)
+     ax.text(0.5, 0.2, "Policy Improvement" + "\n" + r"$\pi \leftarrow \mathrm{greedy}(V)$", ha="center", bbox=props)
+     ax.annotate("", xy=(0.7, 0.3), xytext=(0.7, 0.7), arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5"))
+     ax.annotate("", xy=(0.3, 0.7), xytext=(0.3, 0.3), arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5"))
+
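The max-over-actions backup and the evaluation/improvement cycle sketched by the two functions above can be run numerically. A hedged sketch on a hypothetical 2-state, 2-action MDP (the transition and reward tables are invented for illustration):

```python
import numpy as np

P = np.array([[[1.0, 0.0], [0.0, 1.0]],   # P[s, a, s']: action 0 leads to state 0, action 1 to state 1
              [[1.0, 0.0], [0.0, 1.0]]])
R = np.array([[0.0, 1.0],                 # R[s, a]
              [0.0, 2.0]])
gamma = 0.5

V = np.zeros(2)
for _ in range(100):
    V = np.max(R + gamma * (P @ V), axis=1)   # Bellman optimality backup (the "max" node)

greedy = np.argmax(R + gamma * (P @ V), axis=1)  # one greedy improvement step over the converged V
print(V, greedy)  # converges to V = [3, 4]; the greedy policy picks a=1 in both states
```

Value iteration folds the evaluation and improvement boxes of the cycle into a single update, which is why the backup diagram replaces the policy's action node with a `max`.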
+ def plot_mc_backup(ax):
217
+ """Monte Carlo: Backup diagram (Full trajectory until terminal sT)."""
218
+ ax.axis('off')
219
+ ax.set_title("Monte Carlo Backup", fontsize=12, fontweight='bold')
220
+ nodes = ['s', 's1', 's2', 'sT']
221
+ pos = {n: (0.5, 0.9 - i*0.25) for i, n in enumerate(nodes)}
222
+ for i in range(len(nodes)-1):
223
+ ax.annotate("", xy=pos[nodes[i+1]], xytext=pos[nodes[i]], arrowprops=dict(arrowstyle="->", lw=1.5))
224
+        ax.text(pos[nodes[i]][0]+0.05, pos[nodes[i]][1], nodes[i], va='center')
+    ax.text(pos['sT'][0]+0.05, pos['sT'][1], 'sT', va='center', fontweight='bold')
+    ax.annotate("Update V(s) using G", xy=(0.3, 0.9), xytext=(0.3, 0.15), arrowprops=dict(arrowstyle="->", color='red', connectionstyle="arc3,rad=0.3"))
+
+def plot_mcts(ax):
+    """Monte Carlo: Monte Carlo Tree Search (MCTS) tree diagram."""
+    G = nx.balanced_tree(2, 2, create_using=nx.DiGraph())
+    # Fixed manual layout for the depth-2 binary tree (avoids a graphviz dependency)
+    pos = {0: (0, 0), 1: (-1, -1), 2: (1, -1), 3: (-1.5, -2), 4: (-0.5, -2), 5: (0.5, -2), 6: (1.5, -2)}
+    nx.draw_networkx_nodes(G, pos, ax=ax, node_size=300, node_color='lightyellow', edgecolors='black')
+    nx.draw_networkx_edges(G, pos, ax=ax, arrows=True)
+    ax.set_title("MCTS Tree", fontsize=12, fontweight='bold')
+    ax.axis('off')
+
+def plot_importance_sampling(ax):
+    """Monte Carlo: Importance Sampling Ratio Flow."""
+    ax.axis('off')
+    ax.set_title("Importance Sampling", fontsize=12, fontweight='bold')
+    ax.text(0.5, 0.8, r"$\pi(a|s)$", bbox=dict(boxstyle="circle", fc="lightgreen"), ha='center')
+    ax.text(0.5, 0.2, r"$b(a|s)$", bbox=dict(boxstyle="circle", fc="lightpink"), ha='center')
+    ax.annotate(r"$\rho = \frac{\pi}{b}$", xy=(0.7, 0.5), fontsize=15)
+    ax.annotate("", xy=(0.5, 0.35), xytext=(0.5, 0.65), arrowprops=dict(arrowstyle="<->", lw=2))
+
+def plot_td_backup(ax):
+    """Temporal Difference: TD(0) 1-step backup."""
+    ax.axis('off')
+    ax.set_title("TD(0) Backup", fontsize=12, fontweight='bold')
+    ax.text(0.5, 0.8, "s", bbox=dict(boxstyle="circle", fc="white"), ha='center')
+    ax.text(0.5, 0.2, "s'", bbox=dict(boxstyle="circle", fc="white"), ha='center')
+    ax.annotate(r"$R + \gamma V(s')$", xy=(0.5, 0.4), ha='center', color='blue')
+    ax.annotate("", xy=(0.5, 0.35), xytext=(0.5, 0.65), arrowprops=dict(arrowstyle="<-", lw=2))
+
+def plot_nstep_td(ax):
+    """Temporal Difference: n-step TD backup."""
+    ax.axis('off')
+    ax.set_title("n-step TD Backup", fontsize=12, fontweight='bold')
+    for i in range(4):
+        ax.text(0.5, 0.9-i*0.2, f"s_{i}", bbox=dict(boxstyle="circle", fc="white"), ha='center', fontsize=8)
+        if i < 3:
+            ax.annotate("", xy=(0.5, 0.75-i*0.2), xytext=(0.5, 0.85-i*0.2), arrowprops=dict(arrowstyle="->"))
+    ax.annotate(r"$G_t^{(n)}$", xy=(0.7, 0.5), fontsize=12, color='red')
+
+def plot_eligibility_traces(ax):
+    """Temporal Difference: TD(lambda) eligibility-trace decay curve."""
+    t = np.arange(0, 50)
+    # Each visit bumps the trace by 1, after which it decays geometrically
+    trace = np.zeros_like(t, dtype=float)
+    visits = [5, 20, 35]
+    for v in visits:
+        trace[v:] += (0.8 ** np.arange(len(t)-v))
+    ax.plot(t, trace, color='brown', lw=2)
+    ax.set_title(r"Eligibility Trace $z_t(\lambda)$", fontsize=12, fontweight='bold')
+    ax.set_xlabel("Time")
+    ax.fill_between(t, trace, color='brown', alpha=0.1)
+
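For cross-checking the curve above, a minimal sketch of an accumulating eligibility-trace update; the decay rate 0.8 (standing in for gamma*lambda) and the visit times are the figure's illustrative values, not values fixed by the library.

```python
# Minimal accumulating eligibility trace, mirroring the decay curve above.
# decay = gamma * lambda is assumed to be 0.8, matching the figure.
def accumulating_trace(visits, horizon, decay=0.8):
    trace = [0.0] * horizon
    z = 0.0
    for t in range(horizon):
        z *= decay        # geometric decay each step
        if t in visits:
            z += 1.0      # bump on state visit
        trace[t] = z
    return trace

tr = accumulating_trace({5, 20, 35}, 50)
```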
+def plot_sarsa_backup(ax):
+    """Temporal Difference: SARSA (on-policy) backup."""
+    ax.axis('off')
+    ax.set_title("SARSA Backup", fontsize=12, fontweight='bold')
+    ax.text(0.5, 0.9, "(s,a)", ha='center')
+    ax.text(0.5, 0.1, "(s',a')", ha='center')
+    ax.annotate("", xy=(0.5, 0.2), xytext=(0.5, 0.8), arrowprops=dict(arrowstyle="<-", lw=2, color='orange'))
+    ax.text(0.6, 0.5, "On-policy", rotation=90)
+
+def plot_q_learning_backup(ax):
+    """Temporal Difference: Q-Learning (off-policy) backup."""
+    ax.axis('off')
+    ax.set_title("Q-Learning Backup", fontsize=12, fontweight='bold')
+    ax.text(0.5, 0.9, "(s,a)", ha='center')
+    ax.text(0.5, 0.1, r"$\max_{a'} Q(s',a')$", ha='center', bbox=dict(boxstyle="round", fc="lightcyan"))
+    ax.annotate("", xy=(0.5, 0.25), xytext=(0.5, 0.8), arrowprops=dict(arrowstyle="<-", lw=2, color='blue'))
+
+def plot_double_q(ax):
+    """Temporal Difference: Double Q-Learning / Double DQN."""
+    ax.axis('off')
+    ax.set_title("Double Q-Learning", fontsize=12, fontweight='bold')
+    ax.text(0.5, 0.8, "Network A", bbox=dict(fc="lightyellow"), ha='center')
+    ax.text(0.5, 0.2, "Network B", bbox=dict(fc="lightcyan"), ha='center')
+    ax.annotate("Select $a^*$", xy=(0.3, 0.8), xytext=(0.5, 0.85), arrowprops=dict(arrowstyle="->"))
+    ax.annotate("Eval $Q(s', a^*)$", xy=(0.7, 0.2), xytext=(0.5, 0.15), arrowprops=dict(arrowstyle="->"))
+
+def plot_dueling_dqn(ax):
+    """Temporal Difference: Dueling DQN Architecture."""
+    ax.axis('off')
+    ax.set_title("Dueling DQN", fontsize=12, fontweight='bold')
+    ax.text(0.1, 0.5, "Backbone", bbox=dict(fc="lightgrey"), ha='center', rotation=90)
+    ax.text(0.5, 0.7, "V(s)", bbox=dict(fc="lightgreen"), ha='center')
+    ax.text(0.5, 0.3, "A(s,a)", bbox=dict(fc="lightblue"), ha='center')
+    ax.text(0.9, 0.5, "Q(s,a)", bbox=dict(boxstyle="circle", fc="orange"), ha='center')
+    ax.annotate("", xy=(0.35, 0.7), xytext=(0.15, 0.55), arrowprops=dict(arrowstyle="->"))
+    ax.annotate("", xy=(0.35, 0.3), xytext=(0.15, 0.45), arrowprops=dict(arrowstyle="->"))
+    ax.annotate("", xy=(0.75, 0.55), xytext=(0.6, 0.7), arrowprops=dict(arrowstyle="->"))
+    ax.annotate("", xy=(0.75, 0.45), xytext=(0.6, 0.3), arrowprops=dict(arrowstyle="->"))
+
+def plot_prioritized_replay(ax):
+    """Temporal Difference: Prioritized Experience Replay (PER)."""
+    priorities = np.random.pareto(3, 100)
+    ax.hist(priorities, bins=20, color='teal', alpha=0.7)
+    ax.set_title("Prioritized Replay (TD-Error)", fontsize=12, fontweight='bold')
+    ax.set_xlabel("Priority $P_i$")
+    ax.set_ylabel("Count")
+
+def plot_rainbow_dqn(ax):
+    """Temporal Difference: Rainbow DQN Composite."""
+    ax.axis('off')
+    ax.set_title("Rainbow DQN", fontsize=12, fontweight='bold')
+    features = ["Double", "Dueling", "PER", "Noisy", "Distributional", "n-step"]
+    for i, f in enumerate(features):
+        ax.text(0.5, 0.9 - i*0.15, f, ha='center', bbox=dict(boxstyle="round", fc="ghostwhite"), fontsize=8)
+
+def plot_linear_fa(ax):
+    """Function Approximation: Linear Function Approximation."""
+    ax.axis('off')
+    ax.set_title("Linear Function Approx", fontsize=12, fontweight='bold')
+    ax.text(0.5, 0.8, r"$\phi(s)$ Features", ha='center', bbox=dict(fc="white"))
+    ax.text(0.5, 0.2, r"$w^T \phi(s)$", ha='center', bbox=dict(fc="lightgrey"))
+    ax.annotate("", xy=(0.5, 0.35), xytext=(0.5, 0.65), arrowprops=dict(arrowstyle="->", lw=2))
+
+def plot_nn_layers(ax):
+    """Function Approximation: Neural Network Layers diagram."""
+    ax.axis('off')
+    ax.set_title("NN Layers (Deep RL)", fontsize=12, fontweight='bold')
+    layers = [4, 8, 8, 2]
+    for i, l in enumerate(layers):
+        for j in range(l):
+            ax.scatter(i*0.3, j*0.1 - l*0.05, s=20, c='black')
+    ax.set_xlim(-0.1, 1.0)
+    ax.set_ylim(-0.5, 0.5)
+
+def plot_computation_graph(ax):
+    """Function Approximation: Computation Graph / Backprop Flow."""
+    ax.axis('off')
+    ax.set_title("Computation Graph (DAG)", fontsize=12, fontweight='bold')
+    ax.text(0.1, 0.5, "Input", bbox=dict(boxstyle="circle", fc="white"))
+    ax.text(0.5, 0.5, "Op", bbox=dict(boxstyle="square", fc="lightgrey"))
+    ax.text(0.9, 0.5, "Loss", bbox=dict(boxstyle="circle", fc="salmon"))
+    ax.annotate("", xy=(0.35, 0.5), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->"))
+    ax.annotate("", xy=(0.75, 0.5), xytext=(0.6, 0.5), arrowprops=dict(arrowstyle="->"))
+    ax.annotate("Grad", xy=(0.1, 0.3), xytext=(0.9, 0.3), arrowprops=dict(arrowstyle="->", color='red', connectionstyle="arc3,rad=0.2"))
+
+def plot_target_network(ax):
+    """Function Approximation: Target Network concept."""
+    ax.axis('off')
+    ax.set_title("Target Network Updates", fontsize=12, fontweight='bold')
+    ax.text(0.3, 0.8, r"$Q_\theta$ (Active)", bbox=dict(fc="lightgreen"))
+    ax.text(0.7, 0.8, r"$Q_{\theta^-}$ (Target)", bbox=dict(fc="lightblue"))
+    ax.annotate("periodic copy", xy=(0.6, 0.8), xytext=(0.4, 0.8), arrowprops=dict(arrowstyle="<-", ls='--'))
+
+def plot_ppo_clip(ax):
+    """Policy Gradients: PPO Clipped Surrogate Objective."""
+    epsilon = 0.2
+    r = np.linspace(0.5, 1.5, 100)
+    advantage = 1.0
+    surr1 = r * advantage
+    surr2 = np.clip(r, 1-epsilon, 1+epsilon) * advantage
+    ax.plot(r, surr1, '--', label="r*A")
+    ax.plot(r, np.minimum(surr1, surr2), 'r', label="min(r*A, clip*A)")
+    ax.set_title("PPO-Clip Objective", fontsize=12, fontweight='bold')
+    ax.legend(fontsize=8)
+    ax.axvline(1, color='gray', linestyle=':')
+
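As a numeric sanity check of the clipped surrogate plotted above (epsilon = 0.2 and A = 1 are the same illustrative values the figure uses):

```python
# PPO clipped surrogate for a single (ratio, advantage) pair.
# epsilon = 0.2 matches the figure; the sample ratios are illustrative.
def ppo_clip_objective(ratio, advantage, epsilon=0.2):
    clipped = max(1 - epsilon, min(ratio, 1 + epsilon))
    return min(ratio * advantage, clipped * advantage)

# With A > 0 the objective is capped once the ratio exceeds 1 + epsilon.
vals = [ppo_clip_objective(r, 1.0) for r in (0.5, 1.0, 1.5)]
```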
+def plot_trpo_trust_region(ax):
+    """Policy Gradients: TRPO Trust Region / KL Constraint."""
+    ax.set_title("TRPO Trust Region", fontsize=12, fontweight='bold')
+    circle = plt.Circle((0.5, 0.5), 0.3, color='blue', fill=False, label="KL Constraint")
+    ax.add_artist(circle)
+    ax.scatter(0.5, 0.5, c='black', label=r"$\pi_{old}$")
+    ax.arrow(0.5, 0.5, 0.15, 0.1, head_width=0.03, color='red', label="Update")
+    ax.set_xlim(0, 1); ax.set_ylim(0, 1)
+    ax.axis('off')
+
+def plot_a3c_multi_worker(ax):
+    """Actor-Critic: Asynchronous Multi-worker (A3C)."""
+    ax.axis('off')
+    ax.set_title("A3C Multi-worker", fontsize=12, fontweight='bold')
+    ax.text(0.5, 0.8, "Global Parameters", bbox=dict(fc="gold"), ha='center')
+    for i in range(3):
+        ax.text(0.2 + i*0.3, 0.2, f"Worker {i+1}", bbox=dict(fc="lightgrey"), ha='center', fontsize=8)
+        ax.annotate("", xy=(0.5, 0.7), xytext=(0.2 + i*0.3, 0.3), arrowprops=dict(arrowstyle="<->"))
+
+def plot_sac_arch(ax):
+    """Actor-Critic: SAC (Entropy-regularized)."""
+    ax.axis('off')
+    ax.set_title("SAC Architecture", fontsize=12, fontweight='bold')
+    ax.text(0.5, 0.7, "Actor", bbox=dict(fc="lightgreen"), ha='center')
+    ax.text(0.5, 0.3, "Entropy Bonus", bbox=dict(fc="salmon"), ha='center')
+    ax.text(0.1, 0.5, "State", ha='center')
+    ax.text(0.9, 0.5, "Action", ha='center')
+    ax.annotate("", xy=(0.4, 0.7), xytext=(0.15, 0.5), arrowprops=dict(arrowstyle="->"))
+    ax.annotate("", xy=(0.5, 0.55), xytext=(0.5, 0.4), arrowprops=dict(arrowstyle="->"))
+    ax.annotate("", xy=(0.85, 0.5), xytext=(0.6, 0.7), arrowprops=dict(arrowstyle="->"))
+
+def plot_softmax_exploration(ax):
+    """Exploration: Softmax / Boltzmann probabilities."""
+    x = np.arange(4)
+    logits = [1, 2, 5, 3]
+    for tau in [0.5, 1.0, 5.0]:
+        probs = np.exp(np.array(logits)/tau)
+        probs /= probs.sum()
+        ax.plot(x, probs, marker='o', label=rf"$\tau={tau}$")
+    ax.set_title("Softmax Exploration", fontsize=12, fontweight='bold')
+    ax.legend(fontsize=8)
+    ax.set_xticks(x)
+
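A small check of the Boltzmann probabilities drawn above (the logits [1, 2, 5, 3] are the figure's illustrative values): a lower temperature tau concentrates mass on the highest-valued action, while a higher one flattens the distribution.

```python
import math

# Boltzmann (softmax) action probabilities at temperature tau.
def softmax_probs(logits, tau):
    exps = [math.exp(x / tau) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [1, 2, 5, 3]
cold = softmax_probs(logits, 0.5)  # near-greedy
hot = softmax_probs(logits, 5.0)   # near-uniform
```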
+def plot_ucb_confidence(ax):
+    """Exploration: Upper Confidence Bound (UCB)."""
+    actions = ['A1', 'A2', 'A3']
+    means = [0.6, 0.8, 0.5]
+    conf = [0.3, 0.1, 0.4]
+    ax.bar(actions, means, yerr=conf, capsize=10, color='skyblue', label='Mean Q')
+    ax.set_title("UCB Action Values", fontsize=12, fontweight='bold')
+    ax.set_ylim(0, 1.2)
+
+def plot_intrinsic_motivation(ax):
+    """Exploration: Intrinsic Motivation / Curiosity."""
+    ax.axis('off')
+    ax.set_title("Intrinsic Motivation", fontsize=12, fontweight='bold')
+    ax.text(0.3, 0.5, "World Model", bbox=dict(fc="lightyellow"), ha='center')
+    ax.text(0.7, 0.5, "Prediction\nError", bbox=dict(boxstyle="circle", fc="orange"), ha='center')
+    ax.annotate("", xy=(0.58, 0.5), xytext=(0.42, 0.5), arrowprops=dict(arrowstyle="->"))
+    ax.text(0.85, 0.5, r"$R_{int}$", fontweight='bold')
+
+def plot_entropy_bonus(ax):
+    """Exploration: Entropy Regularization curve."""
+    p = np.linspace(0.01, 0.99, 50)
+    entropy = -(p * np.log(p) + (1-p) * np.log(1-p))
+    ax.plot(p, entropy, color='purple')
+    ax.set_title(r"Entropy $H(\pi)$", fontsize=12, fontweight='bold')
+    ax.set_xlabel("$P(a)$")
+
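The entropy curve above peaks at the uniform policy; a one-liner confirming the maximum of the binary entropy sits at p = 0.5 with value ln 2:

```python
import math

# Binary entropy H(p) = -(p ln p + (1-p) ln(1-p)), as plotted above.
def binary_entropy(p):
    return -(p * math.log(p) + (1 - p) * math.log(1 - p))

peak = binary_entropy(0.5)  # maximum, equal to ln 2
```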
+def plot_options_framework(ax):
+    """Hierarchical RL: Options Framework."""
+    ax.axis('off')
+    ax.set_title("Options Framework", fontsize=12, fontweight='bold')
+    ax.text(0.5, 0.8, "High-level policy\n" + r"$\pi_{hi}$", bbox=dict(fc="lightblue"), ha='center')
+    ax.text(0.2, 0.2, "Option 1", bbox=dict(fc="ivory"), ha='center')
+    ax.text(0.8, 0.2, "Option 2", bbox=dict(fc="ivory"), ha='center')
+    ax.annotate("", xy=(0.3, 0.3), xytext=(0.45, 0.7), arrowprops=dict(arrowstyle="->"))
+    ax.annotate("", xy=(0.7, 0.3), xytext=(0.55, 0.7), arrowprops=dict(arrowstyle="->"))
+
+def plot_feudal_networks(ax):
+    """Hierarchical RL: Feudal Networks / Hierarchy."""
+    ax.axis('off')
+    ax.set_title("Feudal Networks", fontsize=12, fontweight='bold')
+    ax.text(0.5, 0.85, "Manager", bbox=dict(fc="plum"), ha='center')
+    ax.text(0.5, 0.15, "Worker", bbox=dict(fc="wheat"), ha='center')
+    ax.annotate("Goal $g_t$", xy=(0.5, 0.3), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->", lw=2))
+
+def plot_world_model(ax):
+    """Model-Based RL: Learned Dynamics Model."""
+    ax.axis('off')
+    ax.set_title("World Model (Dynamics)", fontsize=12, fontweight='bold')
+    ax.text(0.1, 0.5, "(s,a)", ha='center')
+    ax.text(0.5, 0.5, r"$\hat{P}$", bbox=dict(boxstyle="circle", fc="lightgrey"), ha='center')
+    ax.text(0.9, 0.7, r"$\hat{s}'$", ha='center')
+    ax.text(0.9, 0.3, r"$\hat{r}$", ha='center')
+    ax.annotate("", xy=(0.4, 0.5), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->"))
+    ax.annotate("", xy=(0.8, 0.65), xytext=(0.6, 0.55), arrowprops=dict(arrowstyle="->"))
+    ax.annotate("", xy=(0.8, 0.35), xytext=(0.6, 0.45), arrowprops=dict(arrowstyle="->"))
+
+def plot_model_planning(ax):
+    """Model-Based RL: Planning / Rollouts in imagination."""
+    ax.axis('off')
+    ax.set_title("Model-Based Planning", fontsize=12, fontweight='bold')
+    ax.text(0.1, 0.5, "Real s", ha='center', fontweight='bold')
+    for i in range(3):
+        ax.annotate("", xy=(0.3+i*0.2, 0.5+(i%2)*0.1), xytext=(0.1+i*0.2, 0.5), arrowprops=dict(arrowstyle="->", color='gray'))
+        ax.text(0.3+i*0.2, 0.55+(i%2)*0.1, "imagined", fontsize=7)
+
+def plot_offline_rl(ax):
+    """Offline RL: Fixed dataset of trajectories."""
+    ax.axis('off')
+    ax.set_title("Offline RL Dataset", fontsize=12, fontweight='bold')
+    ax.text(0.5, 0.5, "Static\nDataset\n" + r"$\mathcal{D}$", bbox=dict(boxstyle="round", fc="lightgrey"), ha='center')
+    ax.annotate("No interaction", xy=(0.5, 0.9), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->", color='red'))
+    ax.scatter([0.2, 0.8, 0.3, 0.7], [0.8, 0.8, 0.2, 0.2], marker='x', color='blue')
+
+def plot_cql_regularization(ax):
+    """Offline RL: CQL regularization visualization."""
+    q = np.linspace(-5, 5, 100)
+    penalty = q**2 * 0.1
+    ax.plot(q, penalty, 'r', label='CQL Penalty')
+    ax.set_title("CQL Regularization", fontsize=12, fontweight='bold')
+    ax.set_xlabel("Q-value")
+    ax.legend(fontsize=8)
+
+def plot_multi_agent_interaction(ax):
+    """Multi-Agent RL: Agents communicating or competing."""
+    G = nx.complete_graph(3)
+    pos = nx.spring_layout(G)
+    nx.draw_networkx_nodes(G, pos, ax=ax, node_size=500, node_color=['red', 'blue', 'green'])
+    nx.draw_networkx_edges(G, pos, ax=ax, style='dashed')
+    ax.set_title("Multi-Agent Interaction", fontsize=12, fontweight='bold')
+    ax.axis('off')
+
+def plot_ctde(ax):
+    """Multi-Agent RL: Centralized Training Decentralized Execution (CTDE)."""
+    ax.axis('off')
+    ax.set_title("CTDE Architecture", fontsize=12, fontweight='bold')
+    ax.text(0.5, 0.8, "Centralized Critic", bbox=dict(fc="gold"), ha='center')
+    ax.text(0.2, 0.2, "Agent 1", bbox=dict(fc="lightblue"), ha='center')
+    ax.text(0.8, 0.2, "Agent 2", bbox=dict(fc="lightblue"), ha='center')
+    ax.annotate("", xy=(0.5, 0.7), xytext=(0.25, 0.35), arrowprops=dict(arrowstyle="<-", color='gray'))
+    ax.annotate("", xy=(0.5, 0.7), xytext=(0.75, 0.35), arrowprops=dict(arrowstyle="<-", color='gray'))
+
+def plot_payoff_matrix(ax):
+    """Multi-Agent RL: Cooperative / Competitive Payoff Matrix."""
+    # Prisoner's Dilemma payoffs as (row player, column player) tuples
+    payoffs = [[(3, 3), (0, 5)], [(5, 0), (1, 1)]]
+    ax.axis('off')
+    ax.set_title("Payoff Matrix (Prisoner's Dilemma)", fontsize=12, fontweight='bold')
+    for i in range(2):
+        for j in range(2):
+            ax.text(j, 1-i, str(payoffs[i][j]), ha='center', va='center', bbox=dict(fc="white"))
+    ax.set_xlim(-0.5, 1.5); ax.set_ylim(-0.5, 1.5)
+
+def plot_irl_reward_inference(ax):
+    """Inverse RL: Infer reward from expert demonstrations."""
+    ax.axis('off')
+    ax.set_title("Inferred Reward Heatmap", fontsize=12, fontweight='bold')
+    grid = np.zeros((5, 5))
+    grid[2:4, 2:4] = 1.0  # Expert path
+    ax.imshow(grid, cmap='hot')
+
+def plot_gail_flow(ax):
+    """Inverse RL: GAIL (Generative Adversarial Imitation Learning)."""
+    ax.axis('off')
+    ax.set_title("GAIL Architecture", fontsize=12, fontweight='bold')
+    ax.text(0.2, 0.8, "Expert Data", bbox=dict(fc="lightgrey"), ha='center')
+    ax.text(0.2, 0.2, "Policy (Gen)", bbox=dict(fc="lightgreen"), ha='center')
+    ax.text(0.8, 0.5, "Discriminator", bbox=dict(boxstyle="square", fc="salmon"), ha='center')
+    ax.annotate("", xy=(0.6, 0.55), xytext=(0.35, 0.75), arrowprops=dict(arrowstyle="->"))
+    ax.annotate("", xy=(0.6, 0.45), xytext=(0.35, 0.25), arrowprops=dict(arrowstyle="->"))
+
+def plot_meta_rl_nested_loop(ax):
+    """Meta-RL: Outer loop (meta) + inner loop (adaptation)."""
+    ax.axis('off')
+    ax.set_title("Meta-RL Loops", fontsize=12, fontweight='bold')
+    ax.add_patch(plt.Circle((0.5, 0.5), 0.4, fill=False, ls='--'))
+    ax.add_patch(plt.Circle((0.5, 0.5), 0.2, fill=False))
+    ax.text(0.5, 0.5, "Inner\nLoop", ha='center', fontsize=8)
+    ax.text(0.5, 0.8, "Outer Loop", ha='center', fontsize=10)
+
+def plot_task_distribution(ax):
+    """Meta-RL: Multiple MDPs from distribution."""
+    ax.axis('off')
+    ax.set_title("Task Distribution", fontsize=12, fontweight='bold')
+    for i in range(3):
+        ax.text(0.2 + i*0.3, 0.5, f"Task {i+1}", bbox=dict(boxstyle="round", fc="ivory"), fontsize=8)
+    ax.annotate("sample", xy=(0.5, 0.8), xytext=(0.5, 0.6), arrowprops=dict(arrowstyle="<-"))
+
+def plot_replay_buffer(ax):
+    """Advanced: Experience Replay Buffer (FIFO)."""
+    ax.axis('off')
+    ax.set_title("Experience Replay Buffer", fontsize=12, fontweight='bold')
+    for i in range(5):
+        ax.add_patch(plt.Rectangle((0.1+i*0.15, 0.4), 0.1, 0.2, fill=True, color='lightgrey'))
+        ax.text(0.15+i*0.15, 0.5, f"e_{i}", ha='center')
+    ax.annotate("In", xy=(0.05, 0.5), xytext=(-0.1, 0.5), arrowprops=dict(arrowstyle="->"), annotation_clip=False)
+    ax.annotate("Out (Batch)", xy=(0.85, 0.5), xytext=(1.0, 0.5), arrowprops=dict(arrowstyle="<-"), annotation_clip=False)
+
+def plot_state_visitation(ax):
+    """Advanced: State Visitation / Occupancy Measure."""
+    data = np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], 1000)
+    ax.hexbin(data[:, 0], data[:, 1], gridsize=15, cmap='Blues')
+    ax.set_title("State Visitation Heatmap", fontsize=12, fontweight='bold')
+
+def plot_regret_curve(ax):
+    """Advanced: Regret / Cumulative Regret."""
+    t = np.arange(100)
+    regret = np.sqrt(t) + np.random.normal(0, 0.5, 100)
+    ax.plot(t, regret, color='red', label='Sub-linear Regret')
+    ax.set_title("Cumulative Regret", fontsize=12, fontweight='bold')
+    ax.set_xlabel("Time")
+    ax.legend(fontsize=8)
+
+def plot_attention_weights(ax):
+    """Advanced: Attention Mechanisms (Heatmap)."""
+    weights = np.random.rand(5, 5)
+    ax.imshow(weights, cmap='viridis')
+    ax.set_title("Attention Weight Matrix", fontsize=12, fontweight='bold')
+    ax.set_xticks([]); ax.set_yticks([])
+
+def plot_diffusion_policy(ax):
+    """Advanced: Diffusion Policy denoising steps."""
+    ax.axis('off')
+    ax.set_title("Diffusion Policy (Denoising)", fontsize=12, fontweight='bold')
+    for i in range(4):
+        ax.scatter(0.1+i*0.25, 0.5, s=100/(i+1), c='black', alpha=1.0 - i*0.2)
+        if i < 3:
+            ax.annotate("", xy=(0.25+i*0.25, 0.5), xytext=(0.15+i*0.25, 0.5), arrowprops=dict(arrowstyle="->"))
+    ax.text(0.5, 0.3, "Noise $\\rightarrow$ Action", ha='center', fontsize=8)
+
+def plot_gnn_rl(ax):
+    """Advanced: Graph Neural Networks for RL."""
+    G = nx.star_graph(4)
+    pos = nx.spring_layout(G)
+    nx.draw_networkx_nodes(G, pos, ax=ax, node_size=200, node_color='orange')
+    nx.draw_networkx_edges(G, pos, ax=ax)
+    ax.set_title("GNN Message Passing", fontsize=12, fontweight='bold')
+    ax.axis('off')
+
+def plot_latent_space(ax):
+    """Advanced: World Model / Latent Space."""
+    ax.axis('off')
+    ax.set_title("Latent Space (VAE/Dreamer)", fontsize=12, fontweight='bold')
+    ax.text(0.1, 0.5, "Image", bbox=dict(fc="lightgrey"), ha='center')
+    ax.text(0.5, 0.5, "Latent $z$", bbox=dict(boxstyle="circle", fc="lightpink"), ha='center')
+    ax.text(0.9, 0.5, "Reconstruction", bbox=dict(fc="lightgrey"), ha='center')
+    ax.annotate("", xy=(0.4, 0.5), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->"))
+    ax.annotate("", xy=(0.8, 0.5), xytext=(0.6, 0.5), arrowprops=dict(arrowstyle="->"))
+
+def plot_convergence_log(ax):
+    """Advanced: Convergence Analysis Plots (Log-scale)."""
+    iterations = np.arange(1, 100)
+    error = 10 / iterations**2
+    ax.loglog(iterations, error, color='green')
+    ax.set_title("Value Convergence (Log)", fontsize=12, fontweight='bold')
+    ax.set_xlabel("Iterations")
+    ax.set_ylabel("Error")
+    ax.grid(True, which="both", ls="-", alpha=0.3)
+
+def plot_expected_sarsa_backup(ax):
+    """Temporal Difference: Expected SARSA (expectation over policy)."""
+    ax.axis('off')
+    ax.set_title("Expected SARSA Backup", fontsize=12, fontweight='bold')
+    ax.text(0.5, 0.9, "(s,a)", ha='center')
+    ax.text(0.5, 0.1, r"$\sum_{a'} \pi(a'|s') Q(s',a')$", ha='center', bbox=dict(boxstyle="round", fc="ivory"))
+    ax.annotate("", xy=(0.5, 0.25), xytext=(0.5, 0.8), arrowprops=dict(arrowstyle="<-", lw=2, color='purple'))
+
+def plot_reinforce_flow(ax):
+    """Policy Gradients: REINFORCE (full trajectory flow)."""
+    ax.axis('off')
+    ax.set_title("REINFORCE Flow", fontsize=12, fontweight='bold')
+    steps = ["s0", "a0", "r1", "s1", "...", "GT"]
+    for i, s in enumerate(steps):
+        ax.text(0.1 + i*0.15, 0.5, s, bbox=dict(boxstyle="circle", fc="white"))
+    ax.annotate(r"$\nabla_\theta J \propto G_t \nabla \ln \pi$", xy=(0.5, 0.8), ha='center', fontsize=12, color='darkgreen')
+
+def plot_advantage_scaled_grad(ax):
+    """Policy Gradients: Baseline / Advantage scaled gradient."""
+    ax.axis('off')
+    ax.set_title("Baseline Subtraction", fontsize=12, fontweight='bold')
+    ax.text(0.5, 0.8, r"$(G_t - b(s))$", bbox=dict(fc="salmon"), ha='center')
+    ax.text(0.5, 0.3, r"Scale $\nabla \ln \pi$", ha='center')
+    ax.annotate("", xy=(0.5, 0.4), xytext=(0.5, 0.7), arrowprops=dict(arrowstyle="->"))
+
+def plot_skill_discovery(ax):
+    """Hierarchical RL: Skill Discovery (unsupervised clusters)."""
+    np.random.seed(0)
+    for i in range(3):
+        center = np.random.randn(2) * 2
+        pts = np.random.randn(20, 2) * 0.5 + center
+        ax.scatter(pts[:, 0], pts[:, 1], alpha=0.6, label=f"Skill {i+1}")
+    ax.set_title("Skill Embedding Space", fontsize=12, fontweight='bold')
+    ax.legend(fontsize=8)
+
+def plot_imagination_rollout(ax):
+    """Model-Based RL: Imagination-Augmented Rollouts (I2A)."""
+    ax.axis('off')
+    ax.set_title("Imagination Rollout (I2A)", fontsize=12, fontweight='bold')
+    ax.text(0.1, 0.5, "Input s", ha='center')
+    ax.add_patch(plt.Rectangle((0.3, 0.3), 0.4, 0.4, fill=True, color='lavender'))
+    ax.text(0.5, 0.5, "Imagination\nModule", ha='center')
+    ax.annotate("Imagined Paths", xy=(0.8, 0.5), xytext=(0.5, 0.5), arrowprops=dict(arrowstyle="->", color='gray', connectionstyle="arc3,rad=0.3"))
+
+def plot_policy_gradient_flow(ax):
+    """Policy Gradients: Gradient flow from reward to log-prob (DAG)."""
+    ax.axis('off')
+    ax.set_title("Policy Gradient Flow (DAG)", fontsize=12, fontweight='bold')
+
+    bbox_props = dict(boxstyle="round,pad=0.5", fc="lightgrey", ec="black", lw=1.5)
+    ax.text(0.1, 0.8, r"Trajectory $\tau$", ha="center", va="center", bbox=bbox_props)
+    ax.text(0.5, 0.8, r"Reward $R(\tau)$", ha="center", va="center", bbox=bbox_props)
+    ax.text(0.1, 0.2, r"Log-Prob $\log \pi_\theta$", ha="center", va="center", bbox=bbox_props)
+    ax.text(0.7, 0.5, r"$\nabla_\theta J(\theta)$", ha="center", va="center", bbox=dict(boxstyle="circle,pad=0.3", fc="gold", ec="black"))
+
+    # Draw arrows
+    ax.annotate("", xy=(0.35, 0.8), xytext=(0.2, 0.8), arrowprops=dict(arrowstyle="->", lw=2))
+    ax.annotate("", xy=(0.7, 0.65), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->", lw=2))
+    ax.annotate("", xy=(0.6, 0.4), xytext=(0.25, 0.2), arrowprops=dict(arrowstyle="->", lw=2))
+
+def plot_actor_critic_arch(ax):
+    """Actor-Critic: Three-network diagram (TD3 - actor + two critics)."""
+    ax.axis('off')
+    ax.set_title("TD3 Architecture Diagram", fontsize=12, fontweight='bold')
+
+    # State input
+    ax.text(0.1, 0.5, "State\n" + r"$s$", ha="center", va="center", bbox=dict(boxstyle="circle,pad=0.5", fc="lightblue"))
+
+    # Networks
+    net_props = dict(boxstyle="square,pad=0.8", fc="lightgreen", ec="black")
+    ax.text(0.5, 0.8, r"Actor $\pi_\phi$", ha="center", va="center", bbox=net_props)
+    ax.text(0.5, 0.5, r"Critic 1 $Q_{\theta_1}$", ha="center", va="center", bbox=net_props)
+    ax.text(0.5, 0.2, r"Critic 2 $Q_{\theta_2}$", ha="center", va="center", bbox=net_props)
+
+    # Outputs
+    ax.text(0.8, 0.8, "Action $a$", ha="center", va="center", bbox=dict(boxstyle="circle,pad=0.3", fc="coral"))
+    ax.text(0.8, 0.35, "Min Q-value", ha="center", va="center", bbox=dict(boxstyle="round,pad=0.3", fc="gold"))
+
+    # Connections
+    kwargs = dict(arrowstyle="->", lw=1.5)
+    ax.annotate("", xy=(0.38, 0.8), xytext=(0.15, 0.55), arrowprops=kwargs)  # S -> Actor
+    ax.annotate("", xy=(0.38, 0.5), xytext=(0.15, 0.5), arrowprops=kwargs)   # S -> C1
+    ax.annotate("", xy=(0.38, 0.2), xytext=(0.15, 0.45), arrowprops=kwargs)  # S -> C2
+    ax.annotate("", xy=(0.73, 0.8), xytext=(0.62, 0.8), arrowprops=kwargs)   # Actor -> Action
+    ax.annotate("", xy=(0.68, 0.35), xytext=(0.62, 0.5), arrowprops=kwargs)  # C1 -> Min
+    ax.annotate("", xy=(0.68, 0.35), xytext=(0.62, 0.2), arrowprops=kwargs)  # C2 -> Min
+
+def plot_epsilon_decay(ax):
+    """Exploration: ε-Greedy Strategy Decay Curve."""
+    episodes = np.arange(0, 1000)
+    epsilon = np.maximum(0.01, np.exp(-0.005 * episodes))  # Exponential decay with a floor
+
+    ax.plot(episodes, epsilon, color='purple', lw=2)
+    ax.set_title(r"$\epsilon$-Greedy Decay Curve", fontsize=12, fontweight='bold')
+    ax.set_xlabel("Episodes")
+    ax.set_ylabel(r"Probability $\epsilon$")
+    ax.grid(True, linestyle='--', alpha=0.6)
+    ax.fill_between(episodes, epsilon, color='purple', alpha=0.1)
+
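The decay schedule above can be cross-checked pointwise; the rate 0.005 and floor 0.01 are the figure's illustrative values.

```python
import math

# Exponentially decaying exploration rate with a floor, as in the curve above.
# rate = 0.005 and floor = 0.01 match the figure; both are illustrative choices.
def epsilon_at(episode, rate=0.005, floor=0.01):
    return max(floor, math.exp(-rate * episode))

eps0 = epsilon_at(0)         # starts fully exploratory
eps_late = epsilon_at(1000)  # clamped at the floor by now
```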
743
+ def plot_learning_curve(ax):
744
+ """Advanced / Misc: Learning Curve with Confidence Bands."""
745
+ steps = np.linspace(0, 1e6, 100)
746
+ # Simulate a learning curve converging to a maximum
747
+ mean_return = 100 * (1 - np.exp(-5e-6 * steps)) + np.random.normal(0, 2, len(steps))
748
+ std_dev = 15 * np.exp(-2e-6 * steps) # Variance decreases as policy stabilizes
749
+
750
+ ax.plot(steps, mean_return, color='blue', lw=2, label="PPO (Mean)")
751
+ ax.fill_between(steps, mean_return - std_dev, mean_return + std_dev, color='blue', alpha=0.2, label="±1 Std Dev")
752
+
753
+ ax.set_title("Learning Curve (Return vs Steps)", fontsize=12, fontweight='bold')
754
+ ax.set_xlabel("Environment Steps")
755
+ ax.set_ylabel("Average Episodic Return")
756
+ ax.legend(loc="lower right")
757
+ ax.grid(True, linestyle='--', alpha=0.6)
758
+
759
+ def main():
760
+ # Figure 1: MDP & Environment (7 plots)
761
+ fig1, gs1 = setup_figure("RL: MDP & Environment", 2, 4)
762
+
763
+ plot_agent_env_loop(fig1.add_subplot(gs1[0, 0]))
764
+ plot_mdp_graph(fig1.add_subplot(gs1[0, 1]))
765
+ plot_trajectory(fig1.add_subplot(gs1[0, 2]))
766
+ plot_continuous_space(fig1.add_subplot(gs1[0, 3]))
767
+ plot_reward_landscape(fig1, gs1) # projection='3d' handled inside
768
+ plot_discount_decay(fig1.add_subplot(gs1[1, 1]))
769
+ # row 5 (State Transition Graph) is basically plot_mdp_graph
770
+
771
+ # Layout handled by constrained_layout=True
772
+
773
+ # Figure 2: Value, Policy & Dynamic Programming
774
+ fig2, gs2 = setup_figure("RL: Value, Policy & Dynamic Programming", 2, 4)
775
+ plot_value_heatmap(fig2.add_subplot(gs2[0, 0]))
776
+ plot_action_value_q(fig2.add_subplot(gs2[0, 1]))
777
+ plot_policy_arrows(fig2.add_subplot(gs2[0, 2]))
778
+ plot_advantage_function(fig2.add_subplot(gs2[0, 3]))
779
+ plot_backup_diagram(fig2.add_subplot(gs2[1, 0])) # Policy Eval
780
+ plot_policy_improvement(fig2.add_subplot(gs2[1, 1]))
781
+ plot_value_iteration_backup(fig2.add_subplot(gs2[1, 2]))
782
+ plot_policy_iteration_cycle(fig2.add_subplot(gs2[1, 3]))
783
+
784
+ # Layout handled by constrained_layout=True
785
+
786
+ # Figure 3: Monte Carlo & Temporal Difference
787
+ fig3, gs3 = setup_figure("RL: Monte Carlo & Temporal Difference", 2, 4)
788
+ plot_mc_backup(fig3.add_subplot(gs3[0, 0]))
789
+ plot_mcts(fig3.add_subplot(gs3[0, 1]))
790
+ plot_importance_sampling(fig3.add_subplot(gs3[0, 2]))
791
+ plot_td_backup(fig3.add_subplot(gs3[0, 3]))
792
+ plot_nstep_td(fig3.add_subplot(gs3[1, 0]))
793
+ plot_eligibility_traces(fig3.add_subplot(gs3[1, 1]))
794
+ plot_sarsa_backup(fig3.add_subplot(gs3[1, 2]))
795
+ plot_q_learning_backup(fig3.add_subplot(gs3[1, 3]))
796
+
797
+ # Layout handled by constrained_layout=True
798
+
799
+ # Figure 4: TD Extensions & Function Approximation
800
+ fig4, gs4 = setup_figure("RL: TD Extensions & Function Approximation", 2, 4)
801
+ plot_double_q(fig4.add_subplot(gs4[0, 0]))
802
+ plot_dueling_dqn(fig4.add_subplot(gs4[0, 1]))
803
+ plot_prioritized_replay(fig4.add_subplot(gs4[0, 2]))
804
+ plot_rainbow_dqn(fig4.add_subplot(gs4[0, 3]))
805
+ plot_linear_fa(fig4.add_subplot(gs4[1, 0]))
806
+ plot_nn_layers(fig4.add_subplot(gs4[1, 1]))
807
+ plot_computation_graph(fig4.add_subplot(gs4[1, 2]))
808
+ plot_target_network(fig4.add_subplot(gs4[1, 3]))
809
+
810
+ # Layout handled by constrained_layout=True
811
+
812
+ # Figure 5: Policy Gradients, Actor-Critic & Exploration
813
+ fig5, gs5 = setup_figure("RL: Policy Gradients, Actor-Critic & Exploration", 2, 4)
814
+ plot_policy_gradient_flow(fig5.add_subplot(gs5[0, 0]))
815
+ plot_ppo_clip(fig5.add_subplot(gs5[0, 1]))
816
+ plot_trpo_trust_region(fig5.add_subplot(gs5[0, 2]))
817
+ plot_actor_critic_arch(fig5.add_subplot(gs5[0, 3]))
818
+ plot_a3c_multi_worker(fig5.add_subplot(gs5[1, 0]))
819
+ plot_sac_arch(fig5.add_subplot(gs5[1, 1]))
820
+ plot_softmax_exploration(fig5.add_subplot(gs5[1, 2]))
821
+ plot_ucb_confidence(fig5.add_subplot(gs5[1, 3]))
822
+
823
+ # Layout handled by constrained_layout=True
824
+
825
+ # Figure 6: Hierarchical, Model-Based & Offline RL
826
+ fig6, gs6 = setup_figure("RL: Hierarchical, Model-Based & Offline", 2, 4)
827
+ plot_options_framework(fig6.add_subplot(gs6[0, 0]))
828
+ plot_feudal_networks(fig6.add_subplot(gs6[0, 1]))
829
+ plot_world_model(fig6.add_subplot(gs6[0, 2]))
830
+ plot_model_planning(fig6.add_subplot(gs6[0, 3]))
831
+ plot_offline_rl(fig6.add_subplot(gs6[1, 0]))
832
+ plot_cql_regularization(fig6.add_subplot(gs6[1, 1]))
833
+ plot_epsilon_decay(fig6.add_subplot(gs6[1, 2])) # placeholder/spacer
834
+ plot_intrinsic_motivation(fig6.add_subplot(gs6[1, 3]))
835
+
836
+ # Layout handled by constrained_layout=True
837
+
838
+ # Figure 7: Multi-Agent, IRL & Meta-RL
839
+ fig7, gs7 = setup_figure("RL: Multi-Agent, IRL & Meta-RL", 2, 4)
840
+ plot_multi_agent_interaction(fig7.add_subplot(gs7[0, 0]))
841
+ plot_ctde(fig7.add_subplot(gs7[0, 1]))
842
+ plot_payoff_matrix(fig7.add_subplot(gs7[0, 2]))
843
+ plot_irl_reward_inference(fig7.add_subplot(gs7[0, 3]))
844
+ plot_gail_flow(fig7.add_subplot(gs7[1, 0]))
845
+ plot_meta_rl_nested_loop(fig7.add_subplot(gs7[1, 1]))
846
+ plot_task_distribution(fig7.add_subplot(gs7[1, 2]))
847
+
848
+ # Layout handled by constrained_layout=True
849
+
850
+ # Figure 8: Advanced / Miscellaneous Topics
851
+ fig8, gs8 = setup_figure("RL: Advanced & Miscellaneous", 2, 4)
852
+ plot_replay_buffer(fig8.add_subplot(gs8[0, 0]))
853
+ plot_state_visitation(fig8.add_subplot(gs8[0, 1]))
854
+ plot_regret_curve(fig8.add_subplot(gs8[0, 2]))
855
+ plot_attention_weights(fig8.add_subplot(gs8[0, 3]))
856
+ plot_diffusion_policy(fig8.add_subplot(gs8[1, 0]))
857
+ plot_gnn_rl(fig8.add_subplot(gs8[1, 1]))
858
+ plot_latent_space(fig8.add_subplot(gs8[1, 2]))
859
+ plot_convergence_log(fig8.add_subplot(gs8[1, 3]))
860
+
861
+ # Layout handled by constrained_layout=True
862
+ plt.show()
863
+
+ def save_all_graphs(output_dir="graphs"):
+     """Saves each RL component in `mapping` as a separate PNG file."""
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Component-to-Function Mapping (one entry per e.md row)
+     mapping = {
+         "Agent-Environment Interaction Loop": plot_agent_env_loop,
+         "Markov Decision Process (MDP) Tuple": plot_mdp_graph,
+         "State Transition Graph": plot_mdp_graph,
+         "Trajectory / Episode Sequence": plot_trajectory,
+         "Continuous State/Action Space Visualization": plot_continuous_space,
+         "Reward Function / Landscape": plot_reward_landscape,
+         "Discount Factor (gamma) Effect": plot_discount_decay,
+         "State-Value Function V(s)": plot_value_heatmap,
+         "Action-Value Function Q(s,a)": plot_action_value_q,
+         "Policy pi(s) or pi(a|s)": plot_policy_arrows,
+         "Advantage Function A(s,a)": plot_advantage_function,
+         "Optimal Value Function V* / Q*": plot_value_heatmap,
+         "Policy Evaluation Backup": plot_backup_diagram,
+         "Policy Improvement": plot_policy_improvement,
+         "Value Iteration Backup": plot_value_iteration_backup,
+         "Policy Iteration Full Cycle": plot_policy_iteration_cycle,
+         "Monte Carlo Backup": plot_mc_backup,
+         "Monte Carlo Tree (MCTS)": plot_mcts,
+         "Importance Sampling Ratio": plot_importance_sampling,
+         "TD(0) Backup": plot_td_backup,
+         "Bootstrapping (general)": plot_td_backup,
+         "n-step TD Backup": plot_nstep_td,
+         "TD(lambda) & Eligibility Traces": plot_eligibility_traces,
+         "SARSA Update": plot_sarsa_backup,
+         "Q-Learning Update": plot_q_learning_backup,
+         "Expected SARSA": plot_expected_sarsa_backup,
+         "Double Q-Learning / Double DQN": plot_double_q,
+         "Dueling DQN Architecture": plot_dueling_dqn,
+         "Prioritized Experience Replay": plot_prioritized_replay,
+         "Rainbow DQN Components": plot_rainbow_dqn,
+         "Linear Function Approximation": plot_linear_fa,
+         "Neural Network Layers (MLP, CNN, RNN, Transformer)": plot_nn_layers,
+         "Computation Graph / Backpropagation Flow": plot_computation_graph,
+         "Target Network": plot_target_network,
+         "Policy Gradient Theorem": plot_policy_gradient_flow,
+         "REINFORCE Update": plot_reinforce_flow,
+         "Baseline / Advantage Subtraction": plot_advantage_scaled_grad,
+         "Trust Region (TRPO)": plot_trpo_trust_region,
+         "Proximal Policy Optimization (PPO)": plot_ppo_clip,
+         "Actor-Critic Architecture": plot_actor_critic_arch,
+         "Advantage Actor-Critic (A2C/A3C)": plot_a3c_multi_worker,
+         "Soft Actor-Critic (SAC)": plot_sac_arch,
+         "Twin Delayed DDPG (TD3)": plot_actor_critic_arch,
+         "epsilon-Greedy Strategy": plot_epsilon_decay,
+         "Softmax / Boltzmann Exploration": plot_softmax_exploration,
+         "Upper Confidence Bound (UCB)": plot_ucb_confidence,
+         "Intrinsic Motivation / Curiosity": plot_intrinsic_motivation,
+         "Entropy Regularization": plot_entropy_bonus,
+         "Options Framework": plot_options_framework,
+         "Feudal Networks / Hierarchical Actor-Critic": plot_feudal_networks,
+         "Skill Discovery": plot_skill_discovery,
+         "Learned Dynamics Model": plot_world_model,
+         "Model-Based Planning": plot_model_planning,
+         "Imagination-Augmented Agents (I2A)": plot_imagination_rollout,
+         "Offline Dataset": plot_offline_rl,
+         "Conservative Q-Learning (CQL)": plot_cql_regularization,
+         "Multi-Agent Interaction Graph": plot_multi_agent_interaction,
+         "Centralized Training Decentralized Execution (CTDE)": plot_ctde,
+         "Cooperative / Competitive Payoff Matrix": plot_payoff_matrix,
+         "Reward Inference": plot_irl_reward_inference,
+         "Generative Adversarial Imitation Learning (GAIL)": plot_gail_flow,
+         "Meta-RL Architecture": plot_meta_rl_nested_loop,
+         "Task Distribution Visualization": plot_task_distribution,
+         "Experience Replay Buffer": plot_replay_buffer,
+         "State Visitation / Occupancy Measure": plot_state_visitation,
+         "Learning Curve": plot_learning_curve,
+         "Regret / Cumulative Regret": plot_regret_curve,
+         "Attention Mechanisms (Transformers in RL)": plot_attention_weights,
+         "Diffusion Policy": plot_diffusion_policy,
+         "Graph Neural Networks for RL": plot_gnn_rl,
+         "World Model / Latent Space": plot_latent_space,
+         "Convergence Analysis Plots": plot_convergence_log
+     }
+
+     for name, func in mapping.items():
+         # Sanitize filename
+         filename = re.sub(r'[^a-zA-Z0-9]', '_', name.lower()).strip('_')
+         filename = re.sub(r'_+', '_', filename) + ".png"
+         filepath = os.path.join(output_dir, filename)
+
+         print(f"Generating: {filename} ...")
+
+         plt.close('all')
+
+         if func == plot_reward_landscape:
+             fig = plt.figure(figsize=(10, 8))
+             gs = GridSpec(1, 1, figure=fig)
+             func(fig, gs)
+             plt.savefig(filepath, bbox_inches='tight', dpi=100)
+             plt.close(fig)
+             continue
+
+         fig, ax = plt.subplots(figsize=(10, 8), constrained_layout=True)
+         func(ax)
+         plt.savefig(filepath, bbox_inches='tight', dpi=100)
+         plt.close(fig)
+
+     print(f"\n[SUCCESS] Saved {len(mapping)} graphs to '{output_dir}/' directory.")
+
+ if __name__ == "__main__":
+     import sys
+     if "--save" in sys.argv:
+         save_all_graphs()
+     else:
+         main()
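The sanitization step in `save_all_graphs` is what produces the PNG names in the file list below. A minimal standalone sketch of that logic (the helper name `component_filename` is mine, but the two `re.sub` calls mirror the loop body above):

```python
import re

def component_filename(name: str) -> str:
    """Lowercase, replace non-alphanumerics with '_',
    strip edge underscores, collapse runs, append .png."""
    stem = re.sub(r'[^a-zA-Z0-9]', '_', name.lower()).strip('_')
    return re.sub(r'_+', '_', stem) + ".png"

print(component_filename("Markov Decision Process (MDP) Tuple"))
# → markov_decision_process_mdp_tuple.png
print(component_filename("Action-Value Function Q(s,a)"))
# → action_value_function_q_s_a.png
```

Because distinct mapping keys always sanitize to distinct stems here, duplicate plot functions (e.g. `plot_mdp_graph` under two keys) still yield separate files.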
graphs/action_value_function_q_s_a.png ADDED
graphs/actor_critic_architecture.png ADDED
graphs/advantage_actor_critic_a2c_a3c.png ADDED
graphs/advantage_function_a_s_a.png ADDED
graphs/agent_environment_interaction_loop.png ADDED
graphs/attention_mechanisms_transformers_in_rl.png ADDED
graphs/baseline_advantage_subtraction.png ADDED
graphs/bootstrapping_general.png ADDED
graphs/centralized_training_decentralized_execution_ctde.png ADDED
graphs/computation_graph_backpropagation_flow.png ADDED
graphs/conservative_q_learning_cql.png ADDED
graphs/continuous_state_action_space_visualization.png ADDED
graphs/convergence_analysis_plots.png ADDED
graphs/cooperative_competitive_payoff_matrix.png ADDED
graphs/diffusion_policy.png ADDED
graphs/discount_factor_gamma_effect.png ADDED
graphs/double_q_learning_double_dqn.png ADDED
graphs/dueling_dqn_architecture.png ADDED
graphs/entropy_regularization.png ADDED
graphs/epsilon_greedy_strategy.png ADDED
graphs/expected_sarsa.png ADDED
graphs/experience_replay_buffer.png ADDED
graphs/feudal_networks_hierarchical_actor_critic.png ADDED
graphs/generative_adversarial_imitation_learning_gail.png ADDED
graphs/graph_neural_networks_for_rl.png ADDED
graphs/imagination_augmented_agents_i2a.png ADDED
graphs/importance_sampling_ratio.png ADDED
graphs/intrinsic_motivation_curiosity.png ADDED
graphs/learned_dynamics_model.png ADDED
graphs/learning_curve.png ADDED
graphs/linear_function_approximation.png ADDED
graphs/markov_decision_process_mdp_tuple.png ADDED
graphs/meta_rl_architecture.png ADDED
graphs/model_based_planning.png ADDED
graphs/monte_carlo_backup.png ADDED
graphs/monte_carlo_tree_mcts.png ADDED
graphs/multi_agent_interaction_graph.png ADDED
graphs/n_step_td_backup.png ADDED
graphs/neural_network_layers_mlp_cnn_rnn_transformer.png ADDED
graphs/offline_dataset.png ADDED
graphs/optimal_value_function_v_q.png ADDED
graphs/options_framework.png ADDED
graphs/policy_evaluation_backup.png ADDED
graphs/policy_gradient_theorem.png ADDED
graphs/policy_improvement.png ADDED
graphs/policy_iteration_full_cycle.png ADDED
graphs/policy_pi_s_or_pi_a_s.png ADDED