| |
| import os |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import matplotlib.patches as mpatches |
|
|
|
|
def plot_policy(agent, file_name=None):
    """Plot the agent's deterministic policy, one heat map per setup state.

    For every ``(I1, I2, setup)`` combination the agent is queried through
    ``agent.get_action(obs, deterministic=True)``; the answers are stored in
    a 3-D table which is then drawn as ``agent.POSSIBLE_STATES`` side-by-side
    ``pcolormesh`` panels (rows = I1, columns = I2).

    Args:
        agent: Object exposing ``env`` (with ``n_machines``, ``n_items`` and
            ``max_inventory_level``), ``POSSIBLE_STATES`` and
            ``get_action(obs, deterministic=True)``.
        file_name: If truthy, the figure is written to
            ``results/<file_name>`` (the directory is created when missing)
            and then closed; otherwise the figure is shown with
            ``plt.show()``.

    Returns:
        None. When ``agent.env.n_machines > 1`` a message is printed and the
        function returns without plotting.

    Side effects:
        Overwrites ``agent.policy`` with the computed policy table.
    """
    if agent.env.n_machines > 1:
        print('Impossible to print for n_machine > 1')
        return

    # The colour map and the legend are fixed at three entries.
    # NOTE(review): presumably this matches POSSIBLE_STATES == 3 -- confirm.
    n_colors = 3
    # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap accepts
    # the same (name, lut) arguments and is still supported.
    cmap = plt.get_cmap('viridis', n_colors)

    policy_map = np.zeros(
        (
            agent.env.max_inventory_level[0] + 1,
            agent.env.max_inventory_level[1] + 1,
            agent.env.n_items + 1,
        )
    )
    # np.ndindex walks every cell in C order, i.e. the same order as three
    # nested loops over the axes.
    for i, j, k in np.ndindex(policy_map.shape):
        # The agent is queried with a batch of one observation: shape (1, 3).
        obs = np.expand_dims(np.array([i, j, k]), axis=0)
        policy_map[i, j, k] = agent.get_action(obs, deterministic=True)
    agent.policy = policy_map

    # squeeze=False always yields a 2-D array of Axes, so iterating works
    # even when there is a single panel.
    fig, axs = plt.subplots(1, agent.POSSIBLE_STATES, squeeze=False)
    axs = axs[0]
    fig.suptitle('Found Policy')
    for i, ax in enumerate(axs):
        ax.set_title(f'Setup {i}')
        mesh = ax.pcolormesh(
            agent.policy[:, :, i], cmap=cmap, edgecolors='k', linewidth=2
        )
        mesh.set_clim(0, agent.POSSIBLE_STATES - 1)
        ax.set_xlabel('I2')
        if i == 0:
            ax.set_ylabel('I1')

    # Leave room under the panels for the shared legend.
    fig.subplots_adjust(bottom=0.2)
    # The legend is attached to the last panel and shifted left/below it via
    # bbox_to_anchor (coordinates are relative to that panel).
    axs[-1].legend(
        [mpatches.Patch(color=cmap(b)) for b in range(n_colors)],
        ['{}'.format(b) for b in range(n_colors)],
        loc='upper center', bbox_to_anchor=(-0.8, -0.13),
        fancybox=True, shadow=True, ncol=3
    )

    if file_name:
        out_path = os.path.join('results', file_name)
        # savefig does not create missing directories.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path, bbox_inches='tight')
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()
|
|