Upload 76 files
This view is limited to 50 files because it contains too many changes.
- .gitattributes +1 -0
- README.md +80 -3
- core.py +977 -0
- graphs/action_value_function_q_s_a.png +0 -0
- graphs/actor_critic_architecture.png +0 -0
- graphs/advantage_actor_critic_a2c_a3c.png +0 -0
- graphs/advantage_function_a_s_a.png +0 -0
- graphs/agent_environment_interaction_loop.png +0 -0
- graphs/attention_mechanisms_transformers_in_rl.png +0 -0
- graphs/baseline_advantage_subtraction.png +0 -0
- graphs/bootstrapping_general.png +0 -0
- graphs/centralized_training_decentralized_execution_ctde.png +0 -0
- graphs/computation_graph_backpropagation_flow.png +0 -0
- graphs/conservative_q_learning_cql.png +0 -0
- graphs/continuous_state_action_space_visualization.png +0 -0
- graphs/convergence_analysis_plots.png +0 -0
- graphs/cooperative_competitive_payoff_matrix.png +0 -0
- graphs/diffusion_policy.png +0 -0
- graphs/discount_factor_gamma_effect.png +0 -0
- graphs/double_q_learning_double_dqn.png +0 -0
- graphs/dueling_dqn_architecture.png +0 -0
- graphs/entropy_regularization.png +0 -0
- graphs/epsilon_greedy_strategy.png +0 -0
- graphs/expected_sarsa.png +0 -0
- graphs/experience_replay_buffer.png +0 -0
- graphs/feudal_networks_hierarchical_actor_critic.png +0 -0
- graphs/generative_adversarial_imitation_learning_gail.png +0 -0
- graphs/graph_neural_networks_for_rl.png +0 -0
- graphs/imagination_augmented_agents_i2a.png +0 -0
- graphs/importance_sampling_ratio.png +0 -0
- graphs/intrinsic_motivation_curiosity.png +0 -0
- graphs/learned_dynamics_model.png +0 -0
- graphs/learning_curve.png +0 -0
- graphs/linear_function_approximation.png +0 -0
- graphs/markov_decision_process_mdp_tuple.png +0 -0
- graphs/meta_rl_architecture.png +0 -0
- graphs/model_based_planning.png +0 -0
- graphs/monte_carlo_backup.png +0 -0
- graphs/monte_carlo_tree_mcts.png +0 -0
- graphs/multi_agent_interaction_graph.png +0 -0
- graphs/n_step_td_backup.png +0 -0
- graphs/neural_network_layers_mlp_cnn_rnn_transformer.png +0 -0
- graphs/offline_dataset.png +0 -0
- graphs/optimal_value_function_v_q.png +0 -0
- graphs/options_framework.png +0 -0
- graphs/policy_evaluation_backup.png +0 -0
- graphs/policy_gradient_theorem.png +0 -0
- graphs/policy_improvement.png +0 -0
- graphs/policy_iteration_full_cycle.png +0 -0
- graphs/policy_pi_s_or_pi_a_s.png +0 -0
.gitattributes
CHANGED

@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+graphs/reward_function_landscape.png filter=lfs diff=lfs merge=lfs -text
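The added line is the filter rule that `git lfs track` normally writes. As an illustrative sketch (not part of this commit), the equivalent rule can be appended by hand from a repository root:

```shell
# Emulates what `git lfs track "graphs/reward_function_landscape.png"` appends
# to .gitattributes (illustrative; normally you would run the git-lfs command).
printf '%s\n' 'graphs/reward_function_landscape.png filter=lfs diff=lfs merge=lfs -text' >> .gitattributes
```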
README.md
CHANGED

@@ -1,3 +1,80 @@
---
title: Reinforcement Learning Graphical Representations
date: 2026-04-08
category: Reinforcement Learning
description: A comprehensive gallery of 72 standard RL components and their graphical presentations.
---

# Reinforcement Learning Graphical Representations

This repository contains a full set of 72 visualizations representing foundational concepts, algorithms, and advanced topics in Reinforcement Learning.

| Category | Component | Illustration | Details | Context |
|----------|-----------|--------------|---------|---------|
| **MDP & Environment** | **Agent-Environment Interaction Loop** |  | Core cycle: observation of state → selection of action → environment transition → receipt of reward + next state | All RL algorithms |
| **MDP & Environment** | **Markov Decision Process (MDP) Tuple** |  | (S, A, P, R, γ) with transition dynamics and reward function | Formal definition (P(s′\|s,a) and R(s,a,s′)) |
| **MDP & Environment** | **State Transition Graph** |  | Full probabilistic transitions between discrete states | Gridworld, Taxi, Cliff Walking |
| **MDP & Environment** | **Trajectory / Episode Sequence** |  | Sequence of (s₀, a₀, r₁, s₁, …, s_T) | Monte Carlo, episodic tasks |
| **MDP & Environment** | **Continuous State/Action Space Visualization** |  | High-dimensional spaces (e.g., robot joints, pixel inputs) | Continuous-control tasks (MuJoCo, PyBullet) |
| **MDP & Environment** | **Reward Function / Landscape** |  | Scalar reward as function of state/action | All algorithms; especially reward shaping |
| **MDP & Environment** | **Discount Factor (γ) Effect** |  | How future rewards are weighted | All discounted MDPs |
| **Value & Policy** | **State-Value Function V(s)** |  | Expected return from state s under policy π | Value-based methods |
| **Value & Policy** | **Action-Value Function Q(s,a)** |  | Expected return from state-action pair | Q-learning family |
| **Value & Policy** | **Policy π(s) or π(a\|s)** |  | Arrow overlays on grid (optimal policy), probability bar charts, or softmax heatmaps | |
| **Value & Policy** | **Advantage Function A(s,a)** |  | Q(s,a) – V(s) | A2C, PPO, SAC, TD3 |
| **Value & Policy** | **Optimal Value Function V\* / Q\*** |  | Solution to Bellman optimality | Value iteration, Q-learning |
| **Dynamic Programming** | **Policy Evaluation Backup** |  | Iterative update of V using Bellman expectation | Policy iteration |
| **Dynamic Programming** | **Policy Improvement** |  | Greedy policy update over Q | Policy iteration |
| **Dynamic Programming** | **Value Iteration Backup** |  | Update using Bellman optimality | Value iteration |
| **Dynamic Programming** | **Policy Iteration Full Cycle** |  | Evaluation → Improvement loop | Classic DP methods |
| **Monte Carlo** | **Monte Carlo Backup** |  | Update using full episode return G_t | First-visit / every-visit MC |
| **Monte Carlo** | **Monte Carlo Tree (MCTS)** |  | Search tree with selection, expansion, simulation, backprop | AlphaGo, AlphaZero |
| **Monte Carlo** | **Importance Sampling Ratio** |  | Off-policy correction ρ = π(a\|s) / b(a\|s) | |
| **Temporal Difference** | **TD(0) Backup** |  | Bootstrapped update using R + γV(s′) | TD learning |
| **Temporal Difference** | **Bootstrapping (general)** |  | Using estimated future value instead of full return | All TD methods |
| **Temporal Difference** | **n-step TD Backup** |  | Multi-step return G_t^{(n)} | n-step TD, TD(λ) |
| **Temporal Difference** | **TD(λ) & Eligibility Traces** |  | Decaying trace z_t for credit assignment | TD(λ), SARSA(λ), Q(λ) |
| **Temporal Difference** | **SARSA Update** |  | On-policy TD control | SARSA |
| **Temporal Difference** | **Q-Learning Update** |  | Off-policy TD control | Q-learning, Deep Q-Network |
| **Temporal Difference** | **Expected SARSA** |  | Expectation over next action under policy | Expected SARSA |
| **Temporal Difference** | **Double Q-Learning / Double DQN** |  | Two separate Q estimators to reduce overestimation | Double DQN, TD3 |
| **Temporal Difference** | **Dueling DQN Architecture** |  | Separate streams for state value V(s) and advantage A(s,a) | Dueling DQN |
| **Temporal Difference** | **Prioritized Experience Replay** |  | Importance sampling of transitions by TD error | Prioritized DQN, Rainbow |
| **Temporal Difference** | **Rainbow DQN Components** |  | All extensions combined (Double, Dueling, PER, etc.) | Rainbow DQN |
| **Function Approximation** | **Linear Function Approximation** |  | Feature vector φ(s) → wᵀφ(s) | Tabular → linear FA |
| **Function Approximation** | **Neural Network Layers (MLP, CNN, RNN, Transformer)** |  | Full deep network for value/policy | DQN, A3C, PPO, Decision Transformer |
| **Function Approximation** | **Computation Graph / Backpropagation Flow** |  | Gradient flow through network | All deep RL |
| **Function Approximation** | **Target Network** |  | Frozen copy of Q-network for stability | DQN, DDQN, SAC, TD3 |
| **Policy Gradients** | **Policy Gradient Theorem** |  | ∇_θ J(θ) = E[∇_θ log π(a\|s) · Q(s,a)] | Flow diagram from reward → log-prob → gradient |
| **Policy Gradients** | **REINFORCE Update** |  | Monte-Carlo policy gradient | REINFORCE |
| **Policy Gradients** | **Baseline / Advantage Subtraction** |  | Subtract b(s) to reduce variance | All modern PG |
| **Policy Gradients** | **Trust Region (TRPO)** |  | KL-divergence constraint on policy update | TRPO |
| **Policy Gradients** | **Proximal Policy Optimization (PPO)** |  | Clipped surrogate objective | PPO, PPO-Clip |
| **Actor-Critic** | **Actor-Critic Architecture** |  | Separate or shared actor (policy) + critic (value) networks | A2C, A3C, SAC, TD3 |
| **Actor-Critic** | **Advantage Actor-Critic (A2C/A3C)** |  | Synchronous/asynchronous multi-worker | A2C/A3C |
| **Actor-Critic** | **Soft Actor-Critic (SAC)** |  | Entropy-regularized policy + twin critics | SAC |
| **Actor-Critic** | **Twin Delayed DDPG (TD3)** |  | Twin critics + delayed policy + target smoothing | TD3 |
| **Exploration** | **ε-Greedy Strategy** |  | Probability ε of random action | DQN family |
| **Exploration** | **Softmax / Boltzmann Exploration** |  | Temperature τ in softmax | Softmax policies |
| **Exploration** | **Upper Confidence Bound (UCB)** |  | Optimism in face of uncertainty | UCB1, bandits |
| **Exploration** | **Intrinsic Motivation / Curiosity** |  | Prediction error as intrinsic reward | ICM, RND, Curiosity-driven RL |
| **Exploration** | **Entropy Regularization** |  | Bonus term αH(π) | SAC, maximum-entropy RL |
| **Hierarchical RL** | **Options Framework** |  | High-level policy over options (temporally extended actions) | Option-Critic |
| **Hierarchical RL** | **Feudal Networks / Hierarchical Actor-Critic** |  | Manager-worker hierarchy | Feudal RL |
| **Hierarchical RL** | **Skill Discovery** |  | Unsupervised emergence of reusable skills | DIAYN, VALOR |
| **Model-Based RL** | **Learned Dynamics Model** |  | Learned P̂(s′\|s,a) | Separate model network diagram (often RNN or transformer) |
| **Model-Based RL** | **Model-Based Planning** |  | Rollouts inside learned model | MuZero, DreamerV3 |
| **Model-Based RL** | **Imagination-Augmented Agents (I2A)** |  | Imagination module + policy | I2A |
| **Offline RL** | **Offline Dataset** |  | Fixed batch of trajectories | BC, CQL, IQL |
| **Offline RL** | **Conservative Q-Learning (CQL)** |  | Penalty on out-of-distribution actions | CQL |
| **Multi-Agent RL** | **Multi-Agent Interaction Graph** |  | Agents communicating or competing | MARL, MADDPG |
| **Multi-Agent RL** | **Centralized Training Decentralized Execution (CTDE)** |  | Shared critic during training | QMIX, VDN, MADDPG |
| **Multi-Agent RL** | **Cooperative / Competitive Payoff Matrix** |  | Joint reward for multiple agents | Prisoner's Dilemma, multi-agent gridworlds |
| **Inverse RL / IRL** | **Reward Inference** |  | Infer reward from expert demonstrations | IRL, GAIL |
| **Inverse RL / IRL** | **Generative Adversarial Imitation Learning (GAIL)** |  | Discriminator vs. policy generator | GAIL, AIRL |
| **Meta-RL** | **Meta-RL Architecture** |  | Outer loop (meta-policy) + inner loop (task adaptation) | MAML for RL, RL² |
| **Meta-RL** | **Task Distribution Visualization** |  | Multiple MDPs sampled from meta-distribution | Meta-RL benchmarks |
| **Advanced / Misc** | **Experience Replay Buffer** |  | Stored (s,a,r,s′,done) tuples | DQN and all off-policy deep RL |
| **Advanced / Misc** | **State Visitation / Occupancy Measure** |  | Frequency of visiting each state | All algorithms (analysis) |
| **Advanced / Misc** | **Learning Curve** |  | Average episodic return vs. episodes / steps | Standard performance reporting |
| **Advanced / Misc** | **Regret / Cumulative Regret** |  | Sub-optimality accumulated | Bandits and online RL |
| **Advanced / Misc** | **Attention Mechanisms (Transformers in RL)** |  | Attention weights | Decision Transformer, Trajectory Transformer |
| **Advanced / Misc** | **Diffusion Policy** |  | Denoising diffusion process for action generation | Diffusion-RL policies |
| **Advanced / Misc** | **Graph Neural Networks for RL** |  | Node/edge message passing | Graph RL, relational RL |
| **Advanced / Misc** | **World Model / Latent Space** |  | Encoder-decoder dynamics in latent space | Dreamer, PlaNet |
| **Advanced / Misc** | **Convergence Analysis Plots** |  | Error / value change over iterations | DP, TD, value iteration |
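A couple of the tabulated updates (the discounted return behind the γ-decay plot, and the Q-learning backup) can be sketched numerically. This is an illustrative snippet, not part of the repository code:

```python
import numpy as np

# Discounted return G = sum_t gamma^t * r_{t+1}, as in the gamma-decay figure.
gamma = 0.9
rewards = [1.0, 0.0, 0.0, 10.0]                  # r_1 ... r_T of one episode
G = sum(gamma**t * r for t, r in enumerate(rewards))

# One tabular Q-learning backup: Q(s,a) += alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
Q = np.zeros((5, 2))                             # 5 states, 2 actions
s, a, r, s_next, alpha = 0, 1, 1.0, 3, 0.1
td_target = r + gamma * Q[s_next].max()
Q[s, a] += alpha * (td_target - Q[s, a])
```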
core.py
ADDED
@@ -0,0 +1,977 @@
| 1 |
+
import numpy as np
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import networkx as nx
|
| 4 |
+
from matplotlib.gridspec import GridSpec
|
| 5 |
+
from matplotlib.patches import FancyArrowPatch
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
|
| 10 |
+
def setup_figure(title, rows, cols):
|
| 11 |
+
"""Initializes a new figure and grid layout with constrained_layout to avoid warnings."""
|
| 12 |
+
fig = plt.figure(figsize=(20, 10), constrained_layout=True)
|
| 13 |
+
fig.suptitle(title, fontsize=18, fontweight='bold')
|
| 14 |
+
gs = GridSpec(rows, cols, figure=fig)
|
| 15 |
+
return fig, gs
|
| 16 |
+
|
| 17 |
+
def plot_agent_env_loop(ax):
|
| 18 |
+
"""MDP & Environment: Agent-Environment Interaction Loop (Flowchart)."""
|
| 19 |
+
ax.axis('off')
|
| 20 |
+
ax.set_title("Agent-Environment Interaction", fontsize=12, fontweight='bold')
|
| 21 |
+
|
| 22 |
+
props = dict(boxstyle="round,pad=0.8", fc="ivory", ec="black", lw=1.5)
|
| 23 |
+
ax.text(0.5, 0.8, "Agent", ha="center", va="center", bbox=props, fontsize=12)
|
| 24 |
+
ax.text(0.5, 0.2, "Environment", ha="center", va="center", bbox=props, fontsize=12)
|
| 25 |
+
|
| 26 |
+
# Arrows
|
| 27 |
+
# Agent to Env: Action
|
| 28 |
+
ax.annotate("Action $A_t$", xy=(0.5, 0.35), xytext=(0.5, 0.65),
|
| 29 |
+
arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5", lw=2))
|
| 30 |
+
# Env to Agent: State & Reward
|
| 31 |
+
ax.annotate("State $S_{t+1}$, Reward $R_{t+1}$", xy=(0.5, 0.65), xytext=(0.5, 0.35),
|
| 32 |
+
arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5", lw=2, color='green'))
|
| 33 |
+
|
| 34 |
+
def plot_mdp_graph(ax):
|
| 35 |
+
"""MDP & Environment: Directed graph with probability-weighted arrows."""
|
| 36 |
+
G = nx.DiGraph()
|
| 37 |
+
# Corrected syntax: using a dictionary for edge attributes
|
| 38 |
+
G.add_edges_from([
|
| 39 |
+
('S0', 'S1', {'weight': 0.8}), ('S0', 'S2', {'weight': 0.2}),
|
| 40 |
+
('S1', 'S2', {'weight': 1.0}), ('S2', 'S0', {'weight': 0.5}), ('S2', 'S2', {'weight': 0.5})
|
| 41 |
+
])
|
| 42 |
+
pos = nx.spring_layout(G, seed=42)
|
| 43 |
+
nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, node_size=1500, node_color='lightblue')
|
| 44 |
+
nx.draw_networkx_labels(ax=ax, G=G, pos=pos, font_weight='bold')
|
| 45 |
+
|
| 46 |
+
edge_labels = {(u, v): f"P={d['weight']}" for u, v, d in G.edges(data=True)}
|
| 47 |
+
nx.draw_networkx_edges(ax=ax, G=G, pos=pos, arrowsize=20, edge_color='gray', connectionstyle="arc3,rad=0.1")
|
| 48 |
+
nx.draw_networkx_edge_labels(ax=ax, G=G, pos=pos, edge_labels=edge_labels, font_size=9)
|
| 49 |
+
ax.set_title("MDP State Transition Graph", fontsize=12, fontweight='bold')
|
| 50 |
+
ax.axis('off')
|
| 51 |
+
|
| 52 |
+
def plot_reward_landscape(fig, gs):
|
| 53 |
+
"""MDP & Environment: 3D surface plot of a reward function."""
|
| 54 |
+
# Use the first available slot in gs (handled flexibly for dashboard vs save)
|
| 55 |
+
try:
|
| 56 |
+
ax = fig.add_subplot(gs[0, 1], projection='3d')
|
| 57 |
+
except IndexError:
|
| 58 |
+
ax = fig.add_subplot(gs[0, 0], projection='3d')
|
| 59 |
+
X = np.linspace(-5, 5, 50)
|
| 60 |
+
Y = np.linspace(-5, 5, 50)
|
| 61 |
+
X, Y = np.meshgrid(X, Y)
|
| 62 |
+
Z = np.sin(np.sqrt(X**2 + Y**2)) + (X * 0.1) # Simulated reward landscape
|
| 63 |
+
|
| 64 |
+
surf = ax.plot_surface(X, Y, Z, cmap='viridis', edgecolor='none', alpha=0.9)
|
| 65 |
+
ax.set_title("Reward Function Landscape", fontsize=12, fontweight='bold')
|
| 66 |
+
ax.set_xlabel('State X')
|
| 67 |
+
ax.set_ylabel('State Y')
|
| 68 |
+
ax.set_zlabel('Reward R(s)')
|
| 69 |
+
|
| 70 |
+
def plot_trajectory(ax):
|
| 71 |
+
"""MDP & Environment: Trajectory / Episode Sequence."""
|
| 72 |
+
ax.set_title("Trajectory Sequence", fontsize=12, fontweight='bold')
|
| 73 |
+
states = ['s0', 's1', 's2', 's3', 'sT']
|
| 74 |
+
actions = ['a0', 'a1', 'a2', 'a3']
|
| 75 |
+
rewards = ['r1', 'r2', 'r3', 'r4']
|
| 76 |
+
|
| 77 |
+
for i, s in enumerate(states):
|
| 78 |
+
ax.text(i, 0.5, s, ha='center', va='center', bbox=dict(boxstyle="circle", fc="white"))
|
| 79 |
+
if i < len(actions):
|
| 80 |
+
ax.annotate("", xy=(i+0.8, 0.5), xytext=(i+0.2, 0.5), arrowprops=dict(arrowstyle="->"))
|
| 81 |
+
ax.text(i+0.5, 0.6, actions[i], ha='center', color='blue')
|
| 82 |
+
ax.text(i+0.5, 0.4, rewards[i], ha='center', color='red')
|
| 83 |
+
|
| 84 |
+
ax.set_xlim(-0.5, len(states)-0.5)
|
| 85 |
+
ax.set_ylim(0, 1)
|
| 86 |
+
ax.axis('off')
|
| 87 |
+
|
| 88 |
+
def plot_continuous_space(ax):
|
| 89 |
+
"""MDP & Environment: Continuous State/Action Space Visualization."""
|
| 90 |
+
np.random.seed(42)
|
| 91 |
+
x = np.random.randn(200, 2)
|
| 92 |
+
labels = np.linalg.norm(x, axis=1) > 1.0
|
| 93 |
+
ax.scatter(x[labels, 0], x[labels, 1], c='coral', alpha=0.6, label='High Reward')
|
| 94 |
+
ax.scatter(x[~labels, 0], x[~labels, 1], c='skyblue', alpha=0.6, label='Low Reward')
|
| 95 |
+
ax.set_title("Continuous State Space (2D Projection)", fontsize=12, fontweight='bold')
|
| 96 |
+
ax.legend(fontsize=8)
|
| 97 |
+
|
| 98 |
+
def plot_discount_decay(ax):
|
| 99 |
+
"""MDP & Environment: Discount Factor (gamma) Effect."""
|
| 100 |
+
t = np.arange(0, 20)
|
| 101 |
+
for gamma in [0.5, 0.9, 0.99]:
|
| 102 |
+
ax.plot(t, gamma**t, marker='o', markersize=4, label=rf"$\gamma={gamma}$")
|
| 103 |
+
ax.set_title(r"Discount Factor $\gamma^t$ Decay", fontsize=12, fontweight='bold')
|
| 104 |
+
ax.set_xlabel("Time steps (t)")
|
| 105 |
+
ax.set_ylabel("Weight")
|
| 106 |
+
ax.legend()
|
| 107 |
+
ax.grid(True, alpha=0.3)
|
| 108 |
+
|
| 109 |
+
def plot_value_heatmap(ax):
|
| 110 |
+
"""Value & Policy: State-Value Function V(s) Heatmap (Gridworld)."""
|
| 111 |
+
grid_size = 5
|
| 112 |
+
# Simulate a value landscape where the top right is the goal
|
| 113 |
+
values = np.zeros((grid_size, grid_size))
|
| 114 |
+
for i in range(grid_size):
|
| 115 |
+
for j in range(grid_size):
|
| 116 |
+
values[i, j] = -( (grid_size-1-i)**2 + (grid_size-1-j)**2 ) * 0.5
|
| 117 |
+
values[-1, -1] = 10.0 # Goal state
|
| 118 |
+
|
| 119 |
+
cax = ax.matshow(values, cmap='magma')
|
| 120 |
+
for (i, j), z in np.ndenumerate(values):
|
| 121 |
+
ax.text(j, i, f'{z:0.1f}', ha='center', va='center', color='white' if z < -5 else 'black', fontsize=9)
|
| 122 |
+
|
| 123 |
+
ax.set_title("State-Value Function V(s) Heatmap", fontsize=12, fontweight='bold', pad=15)
|
| 124 |
+
ax.set_xticks(range(grid_size))
|
| 125 |
+
ax.set_yticks(range(grid_size))
|
| 126 |
+
|
| 127 |
+
def plot_backup_diagram(ax):
|
| 128 |
+
"""Dynamic Programming: Policy Evaluation Backup Diagram."""
|
| 129 |
+
G = nx.DiGraph()
|
| 130 |
+
G.add_node("s", layer=0)
|
| 131 |
+
G.add_node("a1", layer=1); G.add_node("a2", layer=1)
|
| 132 |
+
G.add_node("s'_1", layer=2); G.add_node("s'_2", layer=2); G.add_node("s'_3", layer=2)
|
| 133 |
+
|
| 134 |
+
G.add_edges_from([("s", "a1"), ("s", "a2")])
|
| 135 |
+
G.add_edges_from([("a1", "s'_1"), ("a1", "s'_2"), ("a2", "s'_3")])
|
| 136 |
+
|
| 137 |
+
pos = {
|
| 138 |
+
"s": (0.5, 1),
|
| 139 |
+
"a1": (0.25, 0.5), "a2": (0.75, 0.5),
|
| 140 |
+
"s'_1": (0.1, 0), "s'_2": (0.4, 0), "s'_3": (0.75, 0)
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, nodelist=["s", "s'_1", "s'_2", "s'_3"], node_size=800, node_color='white', edgecolors='black')
|
| 144 |
+
nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, nodelist=["a1", "a2"], node_size=300, node_color='black') # Action nodes are solid black dots
|
| 145 |
+
nx.draw_networkx_edges(ax=ax, G=G, pos=pos, arrows=True)
|
| 146 |
+
nx.draw_networkx_labels(ax=ax, G=G, pos=pos, labels={"s": "s", "s'_1": "s'", "s'_2": "s'", "s'_3": "s'"}, font_size=10)
|
| 147 |
+
|
| 148 |
+
ax.set_title("DP Policy Eval Backup", fontsize=12, fontweight='bold')
|
| 149 |
+
ax.set_ylim(-0.2, 1.2)
|
| 150 |
+
ax.axis('off')
|
| 151 |
+
|
| 152 |
+
def plot_action_value_q(ax):
|
| 153 |
+
"""Value & Policy: Action-Value Function Q(s,a) (Heatmap per action stack)."""
|
| 154 |
+
grid = np.random.rand(3, 3)
|
| 155 |
+
ax.imshow(grid, cmap='YlGnBu')
|
| 156 |
+
for (i, j), z in np.ndenumerate(grid):
|
| 157 |
+
ax.text(j, i, f'{z:0.1f}', ha='center', va='center', fontsize=8)
|
| 158 |
+
ax.set_title(r"Action-Value $Q(s, a_{up})$", fontsize=12, fontweight='bold')
|
| 159 |
+
ax.set_xticks([]); ax.set_yticks([])
|
| 160 |
+
|
| 161 |
+
def plot_policy_arrows(ax):
|
| 162 |
+
"""Value & Policy: Policy π(s) as arrow overlays on grid."""
|
| 163 |
+
grid_size = 4
|
| 164 |
+
ax.set_xlim(-0.5, grid_size-0.5)
|
| 165 |
+
ax.set_ylim(-0.5, grid_size-0.5)
|
| 166 |
+
for i in range(grid_size):
|
| 167 |
+
for j in range(grid_size):
|
| 168 |
+
dx, dy = np.random.choice([0, 0.3, -0.3]), np.random.choice([0, 0.3, -0.3])
|
| 169 |
+
if dx == 0 and dy == 0: dx = 0.3
|
| 170 |
+
ax.add_patch(FancyArrowPatch((j, i), (j+dx, i+dy), arrowstyle='->', mutation_scale=15))
|
| 171 |
+
ax.set_title(r"Policy $\pi(s)$ Arrows", fontsize=12, fontweight='bold')
|
| 172 |
+
ax.set_xticks(range(grid_size)); ax.set_yticks(range(grid_size)); ax.grid(True, alpha=0.2)
|
| 173 |
+
|
| 174 |
+
def plot_advantage_function(ax):
|
| 175 |
+
"""Value & Policy: Advantage Function A(s,a) = Q-V."""
|
| 176 |
+
actions = ['A1', 'A2', 'A3', 'A4']
|
| 177 |
+
advantage = [2.1, -1.2, 0.5, -0.8]
|
| 178 |
+
colors = ['green' if v > 0 else 'red' for v in advantage]
|
| 179 |
+
ax.bar(actions, advantage, color=colors, alpha=0.7)
|
| 180 |
+
ax.axhline(0, color='black', lw=1)
|
| 181 |
+
ax.set_title(r"Advantage $A(s, a)$", fontsize=12, fontweight='bold')
|
| 182 |
+
ax.set_ylabel("Value")
|
| 183 |
+
|
| 184 |
+
def plot_policy_improvement(ax):
|
| 185 |
+
"""Dynamic Programming: Policy Improvement (Before vs After)."""
|
| 186 |
+
ax.axis('off')
|
| 187 |
+
ax.set_title("Policy Improvement", fontsize=12, fontweight='bold')
|
| 188 |
+
ax.text(0.2, 0.5, r"$\pi_{old}$", fontsize=15, bbox=dict(boxstyle="round", fc="lightgrey"))
|
| 189 |
+
ax.annotate("", xy=(0.8, 0.5), xytext=(0.3, 0.5), arrowprops=dict(arrowstyle="->", lw=2))
|
| 190 |
+
ax.text(0.5, 0.6, "Greedy\nImprovement", ha='center', fontsize=9)
|
| 191 |
+
ax.text(0.85, 0.5, r"$\pi_{new}$", fontsize=15, bbox=dict(boxstyle="round", fc="lightgreen"))
|
| 192 |
+
|
| 193 |
+
def plot_value_iteration_backup(ax):
|
| 194 |
+
"""Dynamic Programming: Value Iteration Backup Diagram (Max over actions)."""
|
| 195 |
+
G = nx.DiGraph()
|
| 196 |
+
    pos = {"s": (0.5, 1), "max": (0.5, 0.5), "s1": (0.2, 0), "s2": (0.5, 0), "s3": (0.8, 0)}
    G.add_nodes_from(pos.keys())
    G.add_edges_from([("s", "max"), ("max", "s1"), ("max", "s2"), ("max", "s3")])

    nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, node_size=500, node_color='white', edgecolors='black')
    nx.draw_networkx_edges(ax=ax, G=G, pos=pos, arrows=True)
    nx.draw_networkx_labels(ax=ax, G=G, pos=pos, labels={"s": "s", "max": "max", "s1": "s'", "s2": "s'", "s3": "s'"}, font_size=9)
    ax.set_title("Value Iteration Backup", fontsize=12, fontweight='bold')
    ax.axis('off')

def plot_policy_iteration_cycle(ax):
    """Dynamic Programming: Policy Iteration Full Cycle Flowchart."""
    ax.axis('off')
    ax.set_title("Policy Iteration Cycle", fontsize=12, fontweight='bold')
    props = dict(boxstyle="round", fc="aliceblue", ec="black")
    ax.text(0.5, 0.8, r"Policy Evaluation" + "\n" + r"$V \leftarrow V^\pi$", ha="center", bbox=props)
    # Note: matplotlib mathtext does not support \text{}; use \mathrm{} instead
    ax.text(0.5, 0.2, r"Policy Improvement" + "\n" + r"$\pi \leftarrow \mathrm{greedy}(V)$", ha="center", bbox=props)
    ax.annotate("", xy=(0.7, 0.3), xytext=(0.7, 0.7), arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5"))
    ax.annotate("", xy=(0.3, 0.7), xytext=(0.3, 0.3), arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.5"))

def plot_mc_backup(ax):
    """Monte Carlo: Backup diagram (Full trajectory until terminal sT)."""
    ax.axis('off')
    ax.set_title("Monte Carlo Backup", fontsize=12, fontweight='bold')
    nodes = ['s', 's1', 's2', 'sT']
    pos = {n: (0.5, 0.9 - i*0.25) for i, n in enumerate(nodes)}
    for i in range(len(nodes)-1):
        ax.annotate("", xy=pos[nodes[i+1]], xytext=pos[nodes[i]], arrowprops=dict(arrowstyle="->", lw=1.5))
        ax.text(pos[nodes[i]][0]+0.05, pos[nodes[i]][1], nodes[i], va='center')
    ax.text(pos['sT'][0]+0.05, pos['sT'][1], 'sT', va='center', fontweight='bold')
    ax.annotate("Update V(s) using G", xy=(0.3, 0.9), xytext=(0.3, 0.15), arrowprops=dict(arrowstyle="->", color='red', connectionstyle="arc3,rad=0.3"))

def plot_mcts(ax):
    """Monte Carlo: Monte Carlo Tree Search (MCTS) tree diagram."""
    G = nx.balanced_tree(2, 2, create_using=nx.DiGraph())
    # Fixed manual layout (avoids the optional pygraphviz dependency)
    pos = {0: (0, 0), 1: (-1, -1), 2: (1, -1), 3: (-1.5, -2), 4: (-0.5, -2), 5: (0.5, -2), 6: (1.5, -2)}
    nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, node_size=300, node_color='lightyellow', edgecolors='black')
    nx.draw_networkx_edges(ax=ax, G=G, pos=pos, arrows=True)
    ax.set_title("MCTS Tree", fontsize=12, fontweight='bold')
    ax.axis('off')

def plot_importance_sampling(ax):
    """Monte Carlo: Importance Sampling Ratio Flow."""
    ax.axis('off')
    ax.set_title("Importance Sampling", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.8, r"$\pi(a|s)$", bbox=dict(boxstyle="circle", fc="lightgreen"), ha='center')
    ax.text(0.5, 0.2, r"$b(a|s)$", bbox=dict(boxstyle="circle", fc="lightpink"), ha='center')
    ax.annotate(r"$\rho = \frac{\pi}{b}$", xy=(0.7, 0.5), fontsize=15)
    ax.annotate("", xy=(0.5, 0.35), xytext=(0.5, 0.65), arrowprops=dict(arrowstyle="<->", lw=2))
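`plot_importance_sampling` above draws the ratio ρ = π/b; as a quick standalone sketch of how that ratio reweights an off-policy return (the probabilities and return below are illustrative, not part of this module):

```python
# Per-step importance sampling: rho = prod_t pi(a_t|s_t) / b(a_t|s_t).
# Target- and behavior-policy probabilities here are made-up values.
pi_probs = [0.9, 0.8]   # pi(a_t|s_t) along an observed 2-step trajectory
b_probs = [0.5, 0.4]    # b(a_t|s_t) for the same actions

rho = 1.0
for p_pi, p_b in zip(pi_probs, b_probs):
    rho *= p_pi / p_b

G = 2.0                     # return observed under the behavior policy b
weighted_return = rho * G   # contribution to the off-policy estimate of V^pi
print(rho, weighted_return)
```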
def plot_td_backup(ax):
    """Temporal Difference: TD(0) 1-step backup."""
    ax.axis('off')
    ax.set_title("TD(0) Backup", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.8, "s", bbox=dict(boxstyle="circle", fc="white"), ha='center')
    ax.text(0.5, 0.2, "s'", bbox=dict(boxstyle="circle", fc="white"), ha='center')
    ax.annotate(r"$R + \gamma V(s')$", xy=(0.5, 0.4), ha='center', color='blue')
    ax.annotate("", xy=(0.5, 0.35), xytext=(0.5, 0.65), arrowprops=dict(arrowstyle="<-", lw=2))
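The diagram in `plot_td_backup` corresponds to the tabular TD(0) update V(s) ← V(s) + α[R + γV(s') − V(s)]. A minimal numeric sketch with illustrative values (standalone, not part of this module):

```python
# One tabular TD(0) update for a single observed transition s -> s_next.
V = {"s": 0.0, "s_next": 1.0}   # illustrative value table
alpha, gamma = 0.5, 0.9

r = 1.0                                # reward observed on the transition
td_target = r + gamma * V["s_next"]    # R + gamma * V(s')
td_error = td_target - V["s"]          # the TD error delta
V["s"] += alpha * td_error             # V(s) <- V(s) + alpha * delta
print(V["s"])  # 0.95
```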
def plot_nstep_td(ax):
    """Temporal Difference: n-step TD backup."""
    ax.axis('off')
    ax.set_title("n-step TD Backup", fontsize=12, fontweight='bold')
    for i in range(4):
        ax.text(0.5, 0.9-i*0.2, f"s_{i}", bbox=dict(boxstyle="circle", fc="white"), ha='center', fontsize=8)
        if i < 3:
            ax.annotate("", xy=(0.5, 0.75-i*0.2), xytext=(0.5, 0.85-i*0.2), arrowprops=dict(arrowstyle="->"))
    ax.annotate(r"$G_t^{(n)}$", xy=(0.7, 0.5), fontsize=12, color='red')

def plot_eligibility_traces(ax):
    """Temporal Difference: TD(lambda) Eligibility Traces decay curve."""
    t = np.arange(0, 50)
    # Superimpose decaying traces from multiple visits
    trace = np.zeros_like(t, dtype=float)
    visits = [5, 20, 35]
    for v in visits:
        trace[v:] += (0.8 ** np.arange(len(t)-v))
    ax.plot(t, trace, color='brown', lw=2)
    ax.set_title(r"Eligibility Trace $z_t(\lambda)$", fontsize=12, fontweight='bold')
    ax.set_xlabel("Time")
    ax.fill_between(t, trace, color='brown', alpha=0.1)

def plot_sarsa_backup(ax):
    """Temporal Difference: SARSA (On-policy) backup."""
    ax.axis('off')
    ax.set_title("SARSA Backup", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.9, "(s,a)", ha='center')
    ax.text(0.5, 0.1, "(s',a')", ha='center')
    ax.annotate("", xy=(0.5, 0.2), xytext=(0.5, 0.8), arrowprops=dict(arrowstyle="<-", lw=2, color='orange'))
    ax.text(0.6, 0.5, "On-policy", rotation=90)

def plot_q_learning_backup(ax):
    """Temporal Difference: Q-Learning (Off-policy) backup."""
    ax.axis('off')
    ax.set_title("Q-Learning Backup", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.9, "(s,a)", ha='center')
    ax.text(0.5, 0.1, r"$\max_{a'} Q(s',a')$", ha='center', bbox=dict(boxstyle="round", fc="lightcyan"))
    ax.annotate("", xy=(0.5, 0.25), xytext=(0.5, 0.8), arrowprops=dict(arrowstyle="<-", lw=2, color='blue'))
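`plot_sarsa_backup` and `plot_q_learning_backup` differ only in how the bootstrap value at s' is chosen. A toy comparison of the two one-step targets (Q-table and numbers are illustrative):

```python
# SARSA bootstraps on the action actually taken; Q-learning on the max.
Q = {("s2", "left"): 0.2, ("s2", "right"): 0.7}   # illustrative Q-table
r, gamma = 1.0, 0.9

a_next = "left"  # action the behavior policy happened to pick in s2
sarsa_target = r + gamma * Q[("s2", a_next)]                          # uses a' taken
q_learning_target = r + gamma * max(Q[("s2", "left")],
                                    Q[("s2", "right")])               # uses max_a'
print(sarsa_target, q_learning_target)
```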
def plot_double_q(ax):
    """Temporal Difference: Double Q-Learning / Double DQN."""
    ax.axis('off')
    ax.set_title("Double Q-Learning", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.8, "Network A", bbox=dict(fc="lightyellow"), ha='center')
    ax.text(0.5, 0.2, "Network B", bbox=dict(fc="lightcyan"), ha='center')
    ax.annotate("Select $a^*$", xy=(0.3, 0.8), xytext=(0.5, 0.85), arrowprops=dict(arrowstyle="->"))
    ax.annotate("Eval $Q(s', a^*)$", xy=(0.7, 0.2), xytext=(0.5, 0.15), arrowprops=dict(arrowstyle="->"))
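The two boxes in `plot_double_q` stand for decoupled selection and evaluation; a standalone sketch of the Double Q-learning target (both tables are illustrative):

```python
# Double Q-learning: network A picks the argmax action, network B rates it,
# which damps the maximization bias of a single max over noisy estimates.
Q_A = {"left": 0.9, "right": 0.4}   # action-selection estimates at s'
Q_B = {"left": 0.5, "right": 0.8}   # evaluation estimates at s'
r, gamma = 0.0, 0.9

a_star = max(Q_A, key=Q_A.get)      # select with A
target = r + gamma * Q_B[a_star]    # evaluate with B
# A naive single-network max would instead have used max(Q_B.values()) = 0.8.
print(a_star, target)
```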
def plot_dueling_dqn(ax):
    """Temporal Difference: Dueling DQN Architecture."""
    ax.axis('off')
    ax.set_title("Dueling DQN", fontsize=12, fontweight='bold')
    ax.text(0.1, 0.5, "Backbone", bbox=dict(fc="lightgrey"), ha='center', rotation=90)
    ax.text(0.5, 0.7, "V(s)", bbox=dict(fc="lightgreen"), ha='center')
    ax.text(0.5, 0.3, "A(s,a)", bbox=dict(fc="lightblue"), ha='center')
    ax.text(0.9, 0.5, "Q(s,a)", bbox=dict(boxstyle="circle", fc="orange"), ha='center')
    ax.annotate("", xy=(0.35, 0.7), xytext=(0.15, 0.55), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.35, 0.3), xytext=(0.15, 0.45), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.75, 0.55), xytext=(0.6, 0.7), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.75, 0.45), xytext=(0.6, 0.3), arrowprops=dict(arrowstyle="->"))

def plot_prioritized_replay(ax):
    """Temporal Difference: Prioritized Experience Replay (PER)."""
    priorities = np.random.pareto(3, 100)
    ax.hist(priorities, bins=20, color='teal', alpha=0.7)
    ax.set_title("Prioritized Replay (TD-Error)", fontsize=12, fontweight='bold')
    ax.set_xlabel("Priority $P_i$")
    ax.set_ylabel("Count")

def plot_rainbow_dqn(ax):
    """Temporal Difference: Rainbow DQN Composite."""
    ax.axis('off')
    ax.set_title("Rainbow DQN", fontsize=12, fontweight='bold')
    features = ["Double", "Dueling", "PER", "Noisy", "Distributional", "n-step"]
    for i, f in enumerate(features):
        ax.text(0.5, 0.9 - i*0.15, f, ha='center', bbox=dict(boxstyle="round", fc="ghostwhite"), fontsize=8)

def plot_linear_fa(ax):
    """Function Approximation: Linear Function Approximation."""
    ax.axis('off')
    ax.set_title("Linear Function Approx", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.8, r"$\phi(s)$ Features", ha='center', bbox=dict(fc="white"))
    ax.text(0.5, 0.2, r"$w^T \phi(s)$", ha='center', bbox=dict(fc="lightgrey"))
    ax.annotate("", xy=(0.5, 0.35), xytext=(0.5, 0.65), arrowprops=dict(arrowstyle="->", lw=2))

def plot_nn_layers(ax):
    """Function Approximation: Neural Network Layers diagram."""
    ax.axis('off')
    ax.set_title("NN Layers (Deep RL)", fontsize=12, fontweight='bold')
    layers = [4, 8, 8, 2]
    for i, l in enumerate(layers):
        for j in range(l):
            ax.scatter(i*0.3, j*0.1 - l*0.05, s=20, c='black')
    ax.set_xlim(-0.1, 1.0)
    ax.set_ylim(-0.5, 0.5)

def plot_computation_graph(ax):
    """Function Approximation: Computation Graph / Backprop Flow."""
    ax.axis('off')
    ax.set_title("Computation Graph (DAG)", fontsize=12, fontweight='bold')
    ax.text(0.1, 0.5, "Input", bbox=dict(boxstyle="circle", fc="white"))
    ax.text(0.5, 0.5, "Op", bbox=dict(boxstyle="square", fc="lightgrey"))
    ax.text(0.9, 0.5, "Loss", bbox=dict(boxstyle="circle", fc="salmon"))
    ax.annotate("", xy=(0.35, 0.5), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.75, 0.5), xytext=(0.6, 0.5), arrowprops=dict(arrowstyle="->"))
    ax.annotate("Grad", xy=(0.1, 0.3), xytext=(0.9, 0.3), arrowprops=dict(arrowstyle="->", color='red', connectionstyle="arc3,rad=0.2"))

def plot_target_network(ax):
    """Function Approximation: Target Network concept."""
    ax.axis('off')
    ax.set_title("Target Network Updates", fontsize=12, fontweight='bold')
    ax.text(0.3, 0.8, r"$Q_\theta$ (Active)", bbox=dict(fc="lightgreen"))
    ax.text(0.7, 0.8, r"$Q_{\theta^-}$ (Target)", bbox=dict(fc="lightblue"))
    ax.annotate("periodic copy", xy=(0.6, 0.8), xytext=(0.4, 0.8), arrowprops=dict(arrowstyle="<-", ls='--'))
def plot_ppo_clip(ax):
    """Policy Gradients: PPO Clipped Surrogate Objective."""
    epsilon = 0.2
    r = np.linspace(0.5, 1.5, 100)
    advantage = 1.0
    surr1 = r * advantage
    surr2 = np.clip(r, 1-epsilon, 1+epsilon) * advantage
    ax.plot(r, surr1, '--', label="r*A")
    ax.plot(r, np.minimum(surr1, surr2), 'r', label="min(r*A, clip*A)")
    ax.set_title("PPO-Clip Objective", fontsize=12, fontweight='bold')
    ax.legend(fontsize=8)
    ax.axvline(1, color='gray', linestyle=':')
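The curve drawn in `plot_ppo_clip` is the objective L = min(rA, clip(r, 1-ε, 1+ε)A); the same quantity computed directly as a standalone sketch (values illustrative):

```python
import numpy as np

def ppo_clip_objective(ratio, advantage, epsilon=0.2):
    """Clipped surrogate: min(r*A, clip(r, 1-eps, 1+eps)*A)."""
    unclipped = ratio * advantage
    clipped = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantage
    return np.minimum(unclipped, clipped)

# With A > 0 the objective is flat once the ratio exceeds 1 + epsilon,
# so a single update gains nothing by moving the policy further.
print(ppo_clip_objective(1.5, 1.0))  # capped at 1.2
print(ppo_clip_objective(0.9, 1.0))  # inside the clip range: 0.9
```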
def plot_trpo_trust_region(ax):
    """Policy Gradients: TRPO Trust Region / KL Constraint."""
    ax.set_title("TRPO Trust Region", fontsize=12, fontweight='bold')
    circle = plt.Circle((0.5, 0.5), 0.3, color='blue', fill=False, label="KL Constraint")
    ax.add_artist(circle)
    ax.scatter(0.5, 0.5, c='black', label=r"$\pi_{old}$")
    ax.arrow(0.5, 0.5, 0.15, 0.1, head_width=0.03, color='red', label="Update")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis('off')

def plot_a3c_multi_worker(ax):
    """Actor-Critic: Asynchronous Multi-worker (A3C)."""
    ax.axis('off')
    ax.set_title("A3C Multi-worker", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.8, "Global Parameters", bbox=dict(fc="gold"), ha='center')
    for i in range(3):
        ax.text(0.2 + i*0.3, 0.2, f"Worker {i+1}", bbox=dict(fc="lightgrey"), ha='center', fontsize=8)
        ax.annotate("", xy=(0.5, 0.7), xytext=(0.2 + i*0.3, 0.3), arrowprops=dict(arrowstyle="<->"))

def plot_sac_arch(ax):
    """Actor-Critic: SAC (Entropy-regularized)."""
    ax.axis('off')
    ax.set_title("SAC Architecture", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.7, "Actor", bbox=dict(fc="lightgreen"), ha='center')
    ax.text(0.5, 0.3, "Entropy Bonus", bbox=dict(fc="salmon"), ha='center')
    ax.text(0.1, 0.5, "State", ha='center')
    ax.text(0.9, 0.5, "Action", ha='center')
    ax.annotate("", xy=(0.4, 0.7), xytext=(0.15, 0.5), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.5, 0.55), xytext=(0.5, 0.4), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.85, 0.5), xytext=(0.6, 0.7), arrowprops=dict(arrowstyle="->"))

def plot_softmax_exploration(ax):
    """Exploration: Softmax / Boltzmann probabilities."""
    x = np.arange(4)
    logits = [1, 2, 5, 3]
    for tau in [0.5, 1.0, 5.0]:
        probs = np.exp(np.array(logits)/tau)
        probs /= probs.sum()
        ax.plot(x, probs, marker='o', label=rf"$\tau={tau}$")
    ax.set_title("Softmax Exploration", fontsize=12, fontweight='bold')
    ax.legend(fontsize=8)
    ax.set_xticks(x)
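The temperature sweep in `plot_softmax_exploration` comes from the Boltzmann distribution; a standalone, numerically stable version of the same computation (logits illustrative):

```python
import numpy as np

def boltzmann_probs(logits, tau):
    """Softmax over logits/tau; low tau -> near-greedy, high tau -> near-uniform."""
    z = np.array(logits, dtype=float) / tau
    z -= z.max()             # subtract the max to stabilize the exponentials
    p = np.exp(z)
    return p / p.sum()

logits = [1, 2, 5, 3]
print(boltzmann_probs(logits, tau=0.5))  # sharply peaked on index 2
print(boltzmann_probs(logits, tau=5.0))  # close to uniform
```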
def plot_ucb_confidence(ax):
    """Exploration: Upper Confidence Bound (UCB)."""
    actions = ['A1', 'A2', 'A3']
    means = [0.6, 0.8, 0.5]
    conf = [0.3, 0.1, 0.4]
    ax.bar(actions, means, yerr=conf, capsize=10, color='skyblue', label='Mean Q')
    ax.set_title("UCB Action Values", fontsize=12, fontweight='bold')
    ax.set_ylim(0, 1.2)
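The error bars in `plot_ucb_confidence` stand for the optimism bonus; a minimal UCB1 scoring sketch (counts and means are illustrative):

```python
import math

def ucb_score(mean_q, n_action, n_total, c=2.0):
    """UCB1: mean estimate plus an exploration bonus that shrinks with visits."""
    return mean_q + c * math.sqrt(math.log(n_total) / n_action)

means = {"A1": 0.6, "A2": 0.8, "A3": 0.5}
counts = {"A1": 10, "A2": 50, "A3": 5}
n_total = sum(counts.values())

scores = {a: ucb_score(means[a], counts[a], n_total) for a in means}
best = max(scores, key=scores.get)
print(best)  # A3: rarely tried, so its bonus outweighs the lower mean
```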
def plot_intrinsic_motivation(ax):
    """Exploration: Intrinsic Motivation / Curiosity."""
    ax.axis('off')
    ax.set_title("Intrinsic Motivation", fontsize=12, fontweight='bold')
    ax.text(0.3, 0.5, "World Model", bbox=dict(fc="lightyellow"), ha='center')
    ax.text(0.7, 0.5, "Prediction\nError", bbox=dict(boxstyle="circle", fc="orange"), ha='center')
    ax.annotate("", xy=(0.58, 0.5), xytext=(0.42, 0.5), arrowprops=dict(arrowstyle="->"))
    ax.text(0.85, 0.5, r"$R_{int}$", fontweight='bold')

def plot_entropy_bonus(ax):
    """Exploration: Entropy Regularization curve."""
    p = np.linspace(0.01, 0.99, 50)
    entropy = -(p * np.log(p) + (1-p) * np.log(1-p))
    ax.plot(p, entropy, color='purple')
    ax.set_title(r"Entropy $H(\pi)$", fontsize=12, fontweight='bold')
    ax.set_xlabel("$P(a)$")
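`plot_entropy_bonus` plots the binary case of H(π) = -Σ π log π; the general computation is a one-liner (distributions illustrative, not part of this module):

```python
import numpy as np

def policy_entropy(probs):
    """H(pi) = -sum_a pi(a) * log pi(a) for a discrete action distribution."""
    p = np.asarray(probs, dtype=float)
    return float(-(p * np.log(p)).sum())

print(policy_entropy([0.5, 0.5]))    # ln 2, the maximum for two actions
print(policy_entropy([0.99, 0.01]))  # near-deterministic: small bonus
```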
def plot_options_framework(ax):
    """Hierarchical RL: Options Framework."""
    ax.axis('off')
    ax.set_title("Options Framework", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.8, r"High-level policy" + "\n" + r"$\pi_{hi}$", bbox=dict(fc="lightblue"), ha='center')
    ax.text(0.2, 0.2, "Option 1", bbox=dict(fc="ivory"), ha='center')
    ax.text(0.8, 0.2, "Option 2", bbox=dict(fc="ivory"), ha='center')
    ax.annotate("", xy=(0.3, 0.3), xytext=(0.45, 0.7), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.7, 0.3), xytext=(0.55, 0.7), arrowprops=dict(arrowstyle="->"))

def plot_feudal_networks(ax):
    """Hierarchical RL: Feudal Networks / Hierarchy."""
    ax.axis('off')
    ax.set_title("Feudal Networks", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.85, "Manager", bbox=dict(fc="plum"), ha='center')
    ax.text(0.5, 0.15, "Worker", bbox=dict(fc="wheat"), ha='center')
    ax.annotate("Goal $g_t$", xy=(0.5, 0.3), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->", lw=2))

def plot_world_model(ax):
    """Model-Based RL: Learned Dynamics Model."""
    ax.axis('off')
    ax.set_title("World Model (Dynamics)", fontsize=12, fontweight='bold')
    ax.text(0.1, 0.5, "(s,a)", ha='center')
    ax.text(0.5, 0.5, r"$\hat{P}$", bbox=dict(boxstyle="circle", fc="lightgrey"), ha='center')
    ax.text(0.9, 0.7, r"$\hat{s}'$", ha='center')
    ax.text(0.9, 0.3, r"$\hat{r}$", ha='center')
    ax.annotate("", xy=(0.4, 0.5), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.8, 0.65), xytext=(0.6, 0.55), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.8, 0.35), xytext=(0.6, 0.45), arrowprops=dict(arrowstyle="->"))
def plot_model_planning(ax):
    """Model-Based RL: Planning / Rollouts in imagination."""
    ax.axis('off')
    ax.set_title("Model-Based Planning", fontsize=12, fontweight='bold')
    ax.text(0.1, 0.5, "Real s", ha='center', fontweight='bold')
    for i in range(3):
        ax.annotate("", xy=(0.3+i*0.2, 0.5+(i%2)*0.1), xytext=(0.1+i*0.2, 0.5), arrowprops=dict(arrowstyle="->", color='gray'))
        ax.text(0.3+i*0.2, 0.55+(i%2)*0.1, "imagined", fontsize=7)

def plot_offline_rl(ax):
    """Offline RL: Fixed dataset of trajectories."""
    ax.axis('off')
    ax.set_title("Offline RL Dataset", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.5, "Static\nDataset\n" + r"$\mathcal{D}$", bbox=dict(boxstyle="round", fc="lightgrey"), ha='center')
    ax.annotate("No interaction", xy=(0.5, 0.9), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->", color='red'))
    ax.scatter([0.2, 0.8, 0.3, 0.7], [0.8, 0.8, 0.2, 0.2], marker='x', color='blue')

def plot_cql_regularization(ax):
    """Offline RL: CQL regularization visualization."""
    q = np.linspace(-5, 5, 100)
    penalty = q**2 * 0.1
    ax.plot(q, penalty, 'r', label='CQL Penalty')
    ax.set_title("CQL Regularization", fontsize=12, fontweight='bold')
    ax.set_xlabel("Q-value")
    ax.legend(fontsize=8)

def plot_multi_agent_interaction(ax):
    """Multi-Agent RL: Agents communicating or competing."""
    G = nx.complete_graph(3)
    pos = nx.spring_layout(G)
    nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, node_size=500, node_color=['red', 'blue', 'green'])
    nx.draw_networkx_edges(ax=ax, G=G, pos=pos, style='dashed')
    ax.set_title("Multi-Agent Interaction", fontsize=12, fontweight='bold')
    ax.axis('off')

def plot_ctde(ax):
    """Multi-Agent RL: Centralized Training Decentralized Execution (CTDE)."""
    ax.axis('off')
    ax.set_title("CTDE Architecture", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.8, "Centralized Critic", bbox=dict(fc="gold"), ha='center')
    ax.text(0.2, 0.2, "Agent 1", bbox=dict(fc="lightblue"), ha='center')
    ax.text(0.8, 0.2, "Agent 2", bbox=dict(fc="lightblue"), ha='center')
    ax.annotate("", xy=(0.5, 0.7), xytext=(0.25, 0.35), arrowprops=dict(arrowstyle="<-", color='gray'))
    ax.annotate("", xy=(0.5, 0.7), xytext=(0.75, 0.35), arrowprops=dict(arrowstyle="<-", color='gray'))
def plot_payoff_matrix(ax):
    """Multi-Agent RL: Cooperative / Competitive Payoff Matrix."""
    # Nested lists keep the (row, col) payoff tuples intact;
    # np.array would flatten them into a 2x2x2 integer array.
    matrix = [[(3, 3), (0, 5)], [(5, 0), (1, 1)]]
    ax.axis('off')
    ax.set_title("Payoff Matrix (Prisoner's Dilemma)", fontsize=12, fontweight='bold')
    for i in range(2):
        for j in range(2):
            ax.text(j, 1-i, str(matrix[i][j]), ha='center', va='center', bbox=dict(fc="white"))
    ax.set_xlim(-0.5, 1.5)
    ax.set_ylim(-0.5, 1.5)
def plot_irl_reward_inference(ax):
    """Inverse RL: Infer reward from expert demonstrations."""
    ax.axis('off')
    ax.set_title("Inferred Reward Heatmap", fontsize=12, fontweight='bold')
    grid = np.zeros((5, 5))
    grid[2:4, 2:4] = 1.0  # Expert path
    ax.imshow(grid, cmap='hot')

def plot_gail_flow(ax):
    """Inverse RL: GAIL (Generative Adversarial Imitation Learning)."""
    ax.axis('off')
    ax.set_title("GAIL Architecture", fontsize=12, fontweight='bold')
    ax.text(0.2, 0.8, "Expert Data", bbox=dict(fc="lightgrey"), ha='center')
    ax.text(0.2, 0.2, "Policy (Gen)", bbox=dict(fc="lightgreen"), ha='center')
    ax.text(0.8, 0.5, "Discriminator", bbox=dict(boxstyle="square", fc="salmon"), ha='center')
    ax.annotate("", xy=(0.6, 0.55), xytext=(0.35, 0.75), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.6, 0.45), xytext=(0.35, 0.25), arrowprops=dict(arrowstyle="->"))

def plot_meta_rl_nested_loop(ax):
    """Meta-RL: Outer loop (meta) + inner loop (adaptation)."""
    ax.axis('off')
    ax.set_title("Meta-RL Loops", fontsize=12, fontweight='bold')
    ax.add_patch(plt.Circle((0.5, 0.5), 0.4, fill=False, ls='--'))
    ax.add_patch(plt.Circle((0.5, 0.5), 0.2, fill=False))
    ax.text(0.5, 0.5, "Inner\nLoop", ha='center', fontsize=8)
    ax.text(0.5, 0.8, "Outer Loop", ha='center', fontsize=10)

def plot_task_distribution(ax):
    """Meta-RL: Multiple MDPs from distribution."""
    ax.axis('off')
    ax.set_title("Task Distribution", fontsize=12, fontweight='bold')
    for i in range(3):
        ax.text(0.2 + i*0.3, 0.5, f"Task {i+1}", bbox=dict(boxstyle="round", fc="ivory"), fontsize=8)
    ax.annotate("sample", xy=(0.5, 0.8), xytext=(0.5, 0.6), arrowprops=dict(arrowstyle="<-"))

def plot_replay_buffer(ax):
    """Advanced: Experience Replay Buffer (FIFO)."""
    ax.axis('off')
    ax.set_title("Experience Replay Buffer", fontsize=12, fontweight='bold')
    for i in range(5):
        ax.add_patch(plt.Rectangle((0.1+i*0.15, 0.4), 0.1, 0.2, fill=True, color='lightgrey'))
        ax.text(0.15+i*0.15, 0.5, f"e_{i}", ha='center')
    ax.annotate("In", xy=(0.05, 0.5), xytext=(-0.1, 0.5), arrowprops=dict(arrowstyle="->"), annotation_clip=False)
    ax.annotate("Out (Batch)", xy=(0.85, 0.5), xytext=(1.0, 0.5), arrowprops=dict(arrowstyle="<-"), annotation_clip=False)

def plot_state_visitation(ax):
    """Advanced: State Visitation / Occupancy Measure."""
    data = np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], 1000)
    ax.hexbin(data[:, 0], data[:, 1], gridsize=15, cmap='Blues')
    ax.set_title("State Visitation Heatmap", fontsize=12, fontweight='bold')
def plot_regret_curve(ax):
    """Advanced: Regret / Cumulative Regret."""
    t = np.arange(100)
    regret = np.sqrt(t) + np.random.normal(0, 0.5, 100)
    ax.plot(t, regret, color='red', label='Sub-linear Regret')
    ax.set_title("Cumulative Regret", fontsize=12, fontweight='bold')
    ax.set_xlabel("Time")
    ax.legend(fontsize=8)

def plot_attention_weights(ax):
    """Advanced: Attention Mechanisms (Heatmap)."""
    weights = np.random.rand(5, 5)
    ax.imshow(weights, cmap='viridis')
    ax.set_title("Attention Weight Matrix", fontsize=12, fontweight='bold')
    ax.set_xticks([])
    ax.set_yticks([])

def plot_diffusion_policy(ax):
    """Advanced: Diffusion Policy denoising steps."""
    ax.axis('off')
    ax.set_title("Diffusion Policy (Denoising)", fontsize=12, fontweight='bold')
    for i in range(4):
        ax.scatter(0.1+i*0.25, 0.5, s=100/(i+1), c='black', alpha=1.0 - i*0.2)
        if i < 3:
            ax.annotate("", xy=(0.25+i*0.25, 0.5), xytext=(0.15+i*0.25, 0.5), arrowprops=dict(arrowstyle="->"))
    ax.text(0.5, 0.3, r"Noise $\rightarrow$ Action", ha='center', fontsize=8)

def plot_gnn_rl(ax):
    """Advanced: Graph Neural Networks for RL."""
    G = nx.star_graph(4)
    pos = nx.spring_layout(G)
    nx.draw_networkx_nodes(ax=ax, G=G, pos=pos, node_size=200, node_color='orange')
    nx.draw_networkx_edges(ax=ax, G=G, pos=pos)
    ax.set_title("GNN Message Passing", fontsize=12, fontweight='bold')
    ax.axis('off')

def plot_latent_space(ax):
    """Advanced: World Model / Latent Space."""
    ax.axis('off')
    ax.set_title("Latent Space (VAE/Dreamer)", fontsize=12, fontweight='bold')
    ax.text(0.1, 0.5, "Image", bbox=dict(fc="lightgrey"), ha='center')
    ax.text(0.5, 0.5, "Latent $z$", bbox=dict(boxstyle="circle", fc="lightpink"), ha='center')
    ax.text(0.9, 0.5, "Reconstruction", bbox=dict(fc="lightgrey"), ha='center')
    ax.annotate("", xy=(0.4, 0.5), xytext=(0.2, 0.5), arrowprops=dict(arrowstyle="->"))
    ax.annotate("", xy=(0.8, 0.5), xytext=(0.6, 0.5), arrowprops=dict(arrowstyle="->"))
def plot_convergence_log(ax):
    """Advanced: Convergence Analysis Plots (Log-scale)."""
    iterations = np.arange(1, 100)
    error = 10 / iterations**2
    ax.loglog(iterations, error, color='green')
    ax.set_title("Value Convergence (Log)", fontsize=12, fontweight='bold')
    ax.set_xlabel("Iterations")
    ax.set_ylabel("Error")
    ax.grid(True, which="both", ls="-", alpha=0.3)
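The error decay drawn in `plot_convergence_log` reflects the γ-contraction of the Bellman backup; a tiny fixed-policy evaluation that converges to the closed-form solution (the 2-state MDP below is illustrative):

```python
import numpy as np

# Policy evaluation on a 2-state chain: repeated Bellman backups
# V <- R + gamma * P V contract toward the unique fixed point V*.
gamma = 0.9
R = np.array([1.0, 0.0])
P = np.array([[0.5, 0.5],
              [0.0, 1.0]])

V = np.zeros(2)
for _ in range(500):
    V = R + gamma * P @ V   # one Bellman backup sweep

V_star = np.linalg.solve(np.eye(2) - gamma * P, R)  # closed form
print(np.max(np.abs(V - V_star)))  # effectively zero after many sweeps
```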
def plot_expected_sarsa_backup(ax):
    """Temporal Difference: Expected SARSA (Expectation over policy)."""
    ax.axis('off')
    ax.set_title("Expected SARSA Backup", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.9, "(s,a)", ha='center')
    ax.text(0.5, 0.1, r"$\sum_{a'} \pi(a'|s') Q(s',a')$", ha='center', bbox=dict(boxstyle="round", fc="ivory"))
    ax.annotate("", xy=(0.5, 0.25), xytext=(0.5, 0.8), arrowprops=dict(arrowstyle="<-", lw=2, color='purple'))

def plot_reinforce_flow(ax):
    """Policy Gradients: REINFORCE (Full trajectory flow)."""
    ax.axis('off')
    ax.set_title("REINFORCE Flow", fontsize=12, fontweight='bold')
    steps = ["s0", "a0", "r1", "s1", "...", "GT"]
    for i, s in enumerate(steps):
        ax.text(0.1 + i*0.15, 0.5, s, bbox=dict(boxstyle="circle", fc="white"))
    ax.annotate(r"$\nabla_\theta J \propto G_t \nabla \ln \pi$", xy=(0.5, 0.8), ha='center', fontsize=12, color='darkgreen')

def plot_advantage_scaled_grad(ax):
    """Policy Gradients: Baseline / Advantage scaled gradient."""
    ax.axis('off')
    ax.set_title("Baseline Subtraction", fontsize=12, fontweight='bold')
    ax.text(0.5, 0.8, r"$(G_t - b(s))$", bbox=dict(fc="salmon"), ha='center')
    ax.text(0.5, 0.3, r"Scale $\nabla \ln \pi$", ha='center')
    ax.annotate("", xy=(0.5, 0.4), xytext=(0.5, 0.7), arrowprops=dict(arrowstyle="->"))
def plot_skill_discovery(ax):
    """Hierarchical RL: Skill Discovery (Unsupervised clusters)."""
    np.random.seed(0)
    for i in range(3):
        center = np.random.randn(2) * 2
        pts = np.random.randn(20, 2) * 0.5 + center
        ax.scatter(pts[:, 0], pts[:, 1], alpha=0.6, label=f"Skill {i+1}")
    ax.set_title("Skill Embedding Space", fontsize=12, fontweight='bold')
    ax.legend(fontsize=8)

def plot_imagination_rollout(ax):
    """Model-Based RL: Imagination-Augmented Rollouts (I2A)."""
    ax.axis('off')
    ax.set_title("Imagination Rollout (I2A)", fontsize=12, fontweight='bold')
    ax.text(0.1, 0.5, "Input s", ha='center')
    ax.add_patch(plt.Rectangle((0.3, 0.3), 0.4, 0.4, fill=True, color='lavender'))
    ax.text(0.5, 0.5, "Imagination\nModule", ha='center')
    ax.annotate("Imagined Paths", xy=(0.8, 0.5), xytext=(0.5, 0.5), arrowprops=dict(arrowstyle="->", color='gray', connectionstyle="arc3,rad=0.3"))

def plot_policy_gradient_flow(ax):
    """Policy Gradients: Gradient flow from reward to log-prob (DAG)."""
    ax.axis('off')
    ax.set_title("Policy Gradient Flow (DAG)", fontsize=12, fontweight='bold')

    bbox_props = dict(boxstyle="round,pad=0.5", fc="lightgrey", ec="black", lw=1.5)
    ax.text(0.1, 0.8, r"Trajectory $\tau$", ha="center", va="center", bbox=bbox_props)
    ax.text(0.5, 0.8, r"Reward $R(\tau)$", ha="center", va="center", bbox=bbox_props)
    ax.text(0.1, 0.2, r"Log-Prob $\log \pi_\theta$", ha="center", va="center", bbox=bbox_props)
    ax.text(0.7, 0.5, r"$\nabla_\theta J(\theta)$", ha="center", va="center", bbox=dict(boxstyle="circle,pad=0.3", fc="gold", ec="black"))

    # Draw arrows
    ax.annotate("", xy=(0.35, 0.8), xytext=(0.2, 0.8), arrowprops=dict(arrowstyle="->", lw=2))
    ax.annotate("", xy=(0.7, 0.65), xytext=(0.5, 0.75), arrowprops=dict(arrowstyle="->", lw=2))
    ax.annotate("", xy=(0.6, 0.4), xytext=(0.25, 0.2), arrowprops=dict(arrowstyle="->", lw=2))
def plot_actor_critic_arch(ax):
    """Actor-Critic: Three-network diagram (TD3 - actor + two critics)."""
    ax.axis('off')
    ax.set_title("TD3 Architecture Diagram", fontsize=12, fontweight='bold')

    # State input
    ax.text(0.1, 0.5, r"State" + "\n" + r"$s$", ha="center", va="center", bbox=dict(boxstyle="circle,pad=0.5", fc="lightblue"))

    # Networks
    net_props = dict(boxstyle="square,pad=0.8", fc="lightgreen", ec="black")
    ax.text(0.5, 0.8, r"Actor $\pi_\phi$", ha="center", va="center", bbox=net_props)
    ax.text(0.5, 0.5, r"Critic 1 $Q_{\theta_1}$", ha="center", va="center", bbox=net_props)
    ax.text(0.5, 0.2, r"Critic 2 $Q_{\theta_2}$", ha="center", va="center", bbox=net_props)

    # Outputs
    ax.text(0.8, 0.8, "Action $a$", ha="center", va="center", bbox=dict(boxstyle="circle,pad=0.3", fc="coral"))
    ax.text(0.8, 0.35, "Min Q-value", ha="center", va="center", bbox=dict(boxstyle="round,pad=0.3", fc="gold"))

    # Connections
    kwargs = dict(arrowstyle="->", lw=1.5)
    ax.annotate("", xy=(0.38, 0.8), xytext=(0.15, 0.55), arrowprops=kwargs)  # S -> Actor
    ax.annotate("", xy=(0.38, 0.5), xytext=(0.15, 0.5), arrowprops=kwargs)   # S -> C1
    ax.annotate("", xy=(0.38, 0.2), xytext=(0.15, 0.45), arrowprops=kwargs)  # S -> C2
    ax.annotate("", xy=(0.73, 0.8), xytext=(0.62, 0.8), arrowprops=kwargs)   # Actor -> Action
    ax.annotate("", xy=(0.68, 0.35), xytext=(0.62, 0.5), arrowprops=kwargs)  # C1 -> Min
    ax.annotate("", xy=(0.68, 0.35), xytext=(0.62, 0.2), arrowprops=kwargs)  # C2 -> Min
def plot_epsilon_decay(ax):
    """Exploration: ε-Greedy Strategy Decay Curve."""
    episodes = np.arange(0, 1000)
    epsilon = np.maximum(0.01, np.exp(-0.005 * episodes))  # Exponential decay

    ax.plot(episodes, epsilon, color='purple', lw=2)
    ax.set_title(r"$\epsilon$-Greedy Decay Curve", fontsize=12, fontweight='bold')
    ax.set_xlabel("Episodes")
    ax.set_ylabel(r"Probability $\epsilon$")
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.fill_between(episodes, epsilon, color='purple', alpha=0.1)
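The decayed ε from `plot_epsilon_decay` is typically consumed by an ε-greedy action rule like the sketch below (Q-values and names are illustrative, not part of this module):

```python
import random

def epsilon_greedy(Q, epsilon, rng=random):
    """Explore uniformly with probability epsilon, else act greedily w.r.t. Q."""
    if rng.random() < epsilon:
        return rng.choice(list(Q))   # explore: uniformly random action
    return max(Q, key=Q.get)         # exploit: argmax_a Q(a)

Q = {"left": 0.1, "up": 0.9, "right": 0.4}
print(epsilon_greedy(Q, epsilon=0.0))  # up (pure exploitation)
```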
def plot_learning_curve(ax):
    """Advanced / Misc: Learning Curve with Confidence Bands."""
    steps = np.linspace(0, 1e6, 100)
    # Simulate a learning curve converging to a maximum
    mean_return = 100 * (1 - np.exp(-5e-6 * steps)) + np.random.normal(0, 2, len(steps))
    std_dev = 15 * np.exp(-2e-6 * steps)  # Variance decreases as policy stabilizes

    ax.plot(steps, mean_return, color='blue', lw=2, label="PPO (Mean)")
    ax.fill_between(steps, mean_return - std_dev, mean_return + std_dev, color='blue', alpha=0.2, label="±1 Std Dev")

    ax.set_title("Learning Curve (Return vs Steps)", fontsize=12, fontweight='bold')
    ax.set_xlabel("Environment Steps")
    ax.set_ylabel("Average Episodic Return")
    ax.legend(loc="lower right")
    ax.grid(True, linestyle='--', alpha=0.6)
def main():
    # Figure 1: MDP & Environment (7 plots)
    fig1, gs1 = setup_figure("RL: MDP & Environment", 2, 4)

    plot_agent_env_loop(fig1.add_subplot(gs1[0, 0]))
    plot_mdp_graph(fig1.add_subplot(gs1[0, 1]))
    plot_trajectory(fig1.add_subplot(gs1[0, 2]))
    plot_continuous_space(fig1.add_subplot(gs1[0, 3]))
    plot_reward_landscape(fig1, gs1)  # projection='3d' handled inside
    plot_discount_decay(fig1.add_subplot(gs1[1, 1]))
    # row 5 (State Transition Graph) is basically plot_mdp_graph

    # Layout handled by constrained_layout=True

    # Figure 2: Value, Policy & Dynamic Programming
    fig2, gs2 = setup_figure("RL: Value, Policy & Dynamic Programming", 2, 4)
    plot_value_heatmap(fig2.add_subplot(gs2[0, 0]))
    plot_action_value_q(fig2.add_subplot(gs2[0, 1]))
    plot_policy_arrows(fig2.add_subplot(gs2[0, 2]))
    plot_advantage_function(fig2.add_subplot(gs2[0, 3]))
    plot_backup_diagram(fig2.add_subplot(gs2[1, 0]))  # Policy Eval
    plot_policy_improvement(fig2.add_subplot(gs2[1, 1]))
    plot_value_iteration_backup(fig2.add_subplot(gs2[1, 2]))
    plot_policy_iteration_cycle(fig2.add_subplot(gs2[1, 3]))

    # Layout handled by constrained_layout=True

    # Figure 3: Monte Carlo & Temporal Difference
    fig3, gs3 = setup_figure("RL: Monte Carlo & Temporal Difference", 2, 4)
    plot_mc_backup(fig3.add_subplot(gs3[0, 0]))
    plot_mcts(fig3.add_subplot(gs3[0, 1]))
    plot_importance_sampling(fig3.add_subplot(gs3[0, 2]))
    plot_td_backup(fig3.add_subplot(gs3[0, 3]))
    plot_nstep_td(fig3.add_subplot(gs3[1, 0]))
    plot_eligibility_traces(fig3.add_subplot(gs3[1, 1]))
    plot_sarsa_backup(fig3.add_subplot(gs3[1, 2]))
    plot_q_learning_backup(fig3.add_subplot(gs3[1, 3]))

    # Layout handled by constrained_layout=True

    # Figure 4: TD Extensions & Function Approximation
    fig4, gs4 = setup_figure("RL: TD Extensions & Function Approximation", 2, 4)
    plot_double_q(fig4.add_subplot(gs4[0, 0]))
    plot_dueling_dqn(fig4.add_subplot(gs4[0, 1]))
    plot_prioritized_replay(fig4.add_subplot(gs4[0, 2]))
    plot_rainbow_dqn(fig4.add_subplot(gs4[0, 3]))
    plot_linear_fa(fig4.add_subplot(gs4[1, 0]))
    plot_nn_layers(fig4.add_subplot(gs4[1, 1]))
    plot_computation_graph(fig4.add_subplot(gs4[1, 2]))
    plot_target_network(fig4.add_subplot(gs4[1, 3]))

    # Layout handled by constrained_layout=True

    # Figure 5: Policy Gradients, Actor-Critic & Exploration
    fig5, gs5 = setup_figure("RL: Policy Gradients, Actor-Critic & Exploration", 2, 4)
    plot_policy_gradient_flow(fig5.add_subplot(gs5[0, 0]))
    plot_ppo_clip(fig5.add_subplot(gs5[0, 1]))
    plot_trpo_trust_region(fig5.add_subplot(gs5[0, 2]))
    plot_actor_critic_arch(fig5.add_subplot(gs5[0, 3]))
    plot_a3c_multi_worker(fig5.add_subplot(gs5[1, 0]))
    plot_sac_arch(fig5.add_subplot(gs5[1, 1]))
    plot_softmax_exploration(fig5.add_subplot(gs5[1, 2]))
    plot_ucb_confidence(fig5.add_subplot(gs5[1, 3]))

    # Layout handled by constrained_layout=True

    # Figure 6: Hierarchical, Model-Based & Offline RL
    fig6, gs6 = setup_figure("RL: Hierarchical, Model-Based & Offline", 2, 4)
    plot_options_framework(fig6.add_subplot(gs6[0, 0]))
    plot_feudal_networks(fig6.add_subplot(gs6[0, 1]))
    plot_world_model(fig6.add_subplot(gs6[0, 2]))
    plot_model_planning(fig6.add_subplot(gs6[0, 3]))
    plot_offline_rl(fig6.add_subplot(gs6[1, 0]))
    plot_cql_regularization(fig6.add_subplot(gs6[1, 1]))
    plot_epsilon_decay(fig6.add_subplot(gs6[1, 2]))  # placeholder/spacer
    plot_intrinsic_motivation(fig6.add_subplot(gs6[1, 3]))

    # Layout handled by constrained_layout=True

    # Figure 7: Multi-Agent, IRL & Meta-RL
    fig7, gs7 = setup_figure("RL: Multi-Agent, IRL & Meta-RL", 2, 4)
    plot_multi_agent_interaction(fig7.add_subplot(gs7[0, 0]))
    plot_ctde(fig7.add_subplot(gs7[0, 1]))
    plot_payoff_matrix(fig7.add_subplot(gs7[0, 2]))
    plot_irl_reward_inference(fig7.add_subplot(gs7[0, 3]))
    plot_gail_flow(fig7.add_subplot(gs7[1, 0]))
    plot_meta_rl_nested_loop(fig7.add_subplot(gs7[1, 1]))
    plot_task_distribution(fig7.add_subplot(gs7[1, 2]))

    # Layout handled by constrained_layout=True

    # Figure 8: Advanced / Miscellaneous Topics
    fig8, gs8 = setup_figure("RL: Advanced & Miscellaneous", 2, 4)
    plot_replay_buffer(fig8.add_subplot(gs8[0, 0]))
    plot_state_visitation(fig8.add_subplot(gs8[0, 1]))
    plot_regret_curve(fig8.add_subplot(gs8[0, 2]))
    plot_attention_weights(fig8.add_subplot(gs8[0, 3]))
    plot_diffusion_policy(fig8.add_subplot(gs8[1, 0]))
    plot_gnn_rl(fig8.add_subplot(gs8[1, 1]))
    plot_latent_space(fig8.add_subplot(gs8[1, 2]))
    plot_convergence_log(fig8.add_subplot(gs8[1, 3]))

    # Layout handled by constrained_layout=True
    plt.show()

def save_all_graphs(output_dir="graphs"):
    """Saves each of the 74 RL components as a separate PNG file."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Component-to-Function Mapping (Total 74 entries as per e.md rows)
    mapping = {
        "Agent-Environment Interaction Loop": plot_agent_env_loop,
        "Markov Decision Process (MDP) Tuple": plot_mdp_graph,
        "State Transition Graph": plot_mdp_graph,
        "Trajectory / Episode Sequence": plot_trajectory,
        "Continuous State/Action Space Visualization": plot_continuous_space,
        "Reward Function / Landscape": plot_reward_landscape,
        "Discount Factor (gamma) Effect": plot_discount_decay,
        "State-Value Function V(s)": plot_value_heatmap,
        "Action-Value Function Q(s,a)": plot_action_value_q,
        "Policy pi(s) or pi(a|s)": plot_policy_arrows,
        "Advantage Function A(s,a)": plot_advantage_function,
        "Optimal Value Function V* / Q*": plot_value_heatmap,
        "Policy Evaluation Backup": plot_backup_diagram,
        "Policy Improvement": plot_policy_improvement,
        "Value Iteration Backup": plot_value_iteration_backup,
        "Policy Iteration Full Cycle": plot_policy_iteration_cycle,
        "Monte Carlo Backup": plot_mc_backup,
        "Monte Carlo Tree (MCTS)": plot_mcts,
        "Importance Sampling Ratio": plot_importance_sampling,
        "TD(0) Backup": plot_td_backup,
        "Bootstrapping (general)": plot_td_backup,
        "n-step TD Backup": plot_nstep_td,
        "TD(lambda) & Eligibility Traces": plot_eligibility_traces,
        "SARSA Update": plot_sarsa_backup,
        "Q-Learning Update": plot_q_learning_backup,
        "Expected SARSA": plot_expected_sarsa_backup,
        "Double Q-Learning / Double DQN": plot_double_q,
        "Dueling DQN Architecture": plot_dueling_dqn,
        "Prioritized Experience Replay": plot_prioritized_replay,
        "Rainbow DQN Components": plot_rainbow_dqn,
        "Linear Function Approximation": plot_linear_fa,
        "Neural Network Layers (MLP, CNN, RNN, Transformer)": plot_nn_layers,
        "Computation Graph / Backpropagation Flow": plot_computation_graph,
        "Target Network": plot_target_network,
        "Policy Gradient Theorem": plot_policy_gradient_flow,
        "REINFORCE Update": plot_reinforce_flow,
        "Baseline / Advantage Subtraction": plot_advantage_scaled_grad,
        "Trust Region (TRPO)": plot_trpo_trust_region,
        "Proximal Policy Optimization (PPO)": plot_ppo_clip,
        "Actor-Critic Architecture": plot_actor_critic_arch,
        "Advantage Actor-Critic (A2C/A3C)": plot_a3c_multi_worker,
        "Soft Actor-Critic (SAC)": plot_sac_arch,
        "Twin Delayed DDPG (TD3)": plot_actor_critic_arch,
        "epsilon-Greedy Strategy": plot_epsilon_decay,
        "Softmax / Boltzmann Exploration": plot_softmax_exploration,
        "Upper Confidence Bound (UCB)": plot_ucb_confidence,
        "Intrinsic Motivation / Curiosity": plot_intrinsic_motivation,
        "Entropy Regularization": plot_entropy_bonus,
        "Options Framework": plot_options_framework,
        "Feudal Networks / Hierarchical Actor-Critic": plot_feudal_networks,
        "Skill Discovery": plot_skill_discovery,
        "Learned Dynamics Model": plot_world_model,
        "Model-Based Planning": plot_model_planning,
        "Imagination-Augmented Agents (I2A)": plot_imagination_rollout,
        "Offline Dataset": plot_offline_rl,
        "Conservative Q-Learning (CQL)": plot_cql_regularization,
        "Multi-Agent Interaction Graph": plot_multi_agent_interaction,
        "Centralized Training Decentralized Execution (CTDE)": plot_ctde,
        "Cooperative / Competitive Payoff Matrix": plot_payoff_matrix,
        "Reward Inference": plot_irl_reward_inference,
        "Generative Adversarial Imitation Learning (GAIL)": plot_gail_flow,
        "Meta-RL Architecture": plot_meta_rl_nested_loop,
        "Task Distribution Visualization": plot_task_distribution,
        "Experience Replay Buffer": plot_replay_buffer,
        "State Visitation / Occupancy Measure": plot_state_visitation,
        "Learning Curve": plot_learning_curve,
        "Regret / Cumulative Regret": plot_regret_curve,
        "Attention Mechanisms (Transformers in RL)": plot_attention_weights,
        "Diffusion Policy": plot_diffusion_policy,
        "Graph Neural Networks for RL": plot_gnn_rl,
        "World Model / Latent Space": plot_latent_space,
        "Convergence Analysis Plots": plot_convergence_log,
    }

    for name, func in mapping.items():
        # Sanitize filename
        filename = re.sub(r'[^a-zA-Z0-9]', '_', name.lower()).strip('_')
        filename = re.sub(r'_+', '_', filename) + ".png"
        filepath = os.path.join(output_dir, filename)

        print(f"Generating: {filename} ...")

        plt.close('all')

        # plot_reward_landscape draws its own 3D subplot, so it takes (fig, gs)
        if func == plot_reward_landscape:
            fig = plt.figure(figsize=(10, 8))
            gs = GridSpec(1, 1, figure=fig)
            func(fig, gs)
            plt.savefig(filepath, bbox_inches='tight', dpi=100)
            plt.close(fig)
            continue

        fig, ax = plt.subplots(figsize=(10, 8), constrained_layout=True)
        func(ax)
        plt.savefig(filepath, bbox_inches='tight', dpi=100)
        plt.close(fig)

    print(f"\n[SUCCESS] Saved {len(mapping)} graphs to '{output_dir}/' directory.")

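# Standalone sketch of the filename sanitization used in save_all_graphs():
# the same two re.sub steps, factored into a helper for illustration. The
# helper name `sanitize_name` is ours, not part of the original script.
def sanitize_name(name):
    """Lowercase, map non-alphanumerics to '_', collapse runs, append '.png'."""
    s = re.sub(r'[^a-zA-Z0-9]', '_', name.lower()).strip('_')
    return re.sub(r'_+', '_', s) + ".png"

# e.g. "Markov Decision Process (MDP) Tuple" -> "markov_decision_process_mdp_tuple.png"
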
if __name__ == "__main__":
    import sys
    if "--save" in sys.argv:
        save_all_graphs()
    else:
        main()
graphs/offline_dataset.png ADDED
graphs/optimal_value_function_v_q.png ADDED
graphs/options_framework.png ADDED
graphs/policy_evaluation_backup.png ADDED
graphs/policy_gradient_theorem.png ADDED
graphs/policy_improvement.png ADDED
graphs/policy_iteration_full_cycle.png ADDED
graphs/policy_pi_s_or_pi_a_s.png ADDED