File size: 1,798 Bytes
# -*- coding: utf-8 -*-
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
def plot_policy(agent, file_name=None):
    """Plot the agent's deterministic policy as one heatmap per panel.

    For every combination ``(i, j, k)`` with ``i`` in
    ``0..agent.env.max_inventory_level[0]``, ``j`` in
    ``0..agent.env.max_inventory_level[1]`` and ``k`` in
    ``0..agent.env.n_items``, the agent is queried through
    ``agent.get_action(obs, deterministic=True)`` with ``obs`` being the
    ``(1, 3)`` array ``[[i, j, k]]``.  The answers are stored in a 3-D array
    that is also written to ``agent.policy`` (side effect).

    One panel is drawn per value in ``range(agent.POSSIBLE_STATES)``; panel
    ``s`` shows ``agent.policy[:, :, s]`` with the first inventory index
    ("I1") on the y axis and the second ("I2") on the x axis.

    Args:
        agent: Object exposing ``env.n_machines``,
            ``env.max_inventory_level``, ``env.n_items``,
            ``get_action(obs, deterministic=True)`` and ``POSSIBLE_STATES``.
        file_name: If truthy, the figure is saved to
            ``results/<file_name>`` (the directory is created when
            missing).  Otherwise the figure is shown with ``plt.show()``.

    Returns:
        None.  When ``agent.env.n_machines > 1`` a message is printed and
        nothing is computed or plotted.
    """
    if agent.env.n_machines > 1:
        print('Impossible to print for n_machine > 1')
        return

    # NOTE(review): the palette and the legend are hard-coded to three
    # entries, while the colour limits below use agent.POSSIBLE_STATES;
    # presumably both are 3 in the supported setting -- TODO confirm.
    n_colors = 3
    # plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap takes the
    # same (name, lut) arguments and returns the resampled colormap.
    cmap = plt.get_cmap('viridis', n_colors)

    policy_map = np.zeros(
        (
            agent.env.max_inventory_level[0] + 1,
            agent.env.max_inventory_level[1] + 1,
            agent.env.n_items + 1,
        )
    )
    # TODO: finish generalizing this function
    # np.ndindex walks the cells in the same order as three nested loops
    # (last index varying fastest).
    for i, j, k in np.ndindex(*policy_map.shape):
        obs = np.expand_dims(np.array([i, j, k]), axis=0)
        policy_map[i, j, k] = agent.get_action(obs, deterministic=True)
    agent.policy = policy_map

    # squeeze=False always returns a 2-D array of Axes, so a single panel
    # (POSSIBLE_STATES == 1) can be iterated like any other count.
    fig, axs = plt.subplots(1, agent.POSSIBLE_STATES, squeeze=False)
    axes = axs[0]
    fig.suptitle('Found Policy')
    for idx, ax in enumerate(axes):
        ax.set_title(f'Setup {idx}')
        mesh = ax.pcolormesh(
            agent.policy[:, :, idx], cmap=cmap, edgecolors='k', linewidth=2
        )
        mesh.set_clim(0, agent.POSSIBLE_STATES - 1)
        ax.set_xlabel('I2')
        if idx == 0:
            ax.set_ylabel('I1')

    # Legend: one coloured patch per colormap entry, attached to the last
    # panel and pushed below the plots.
    bound = list(range(n_colors))
    fig.subplots_adjust(bottom=0.2)
    axes[-1].legend(
        [mpatches.Patch(color=cmap(b)) for b in bound],
        [str(b) for b in bound],
        loc='upper center', bbox_to_anchor=(-0.8, -0.13),
        fancybox=True, shadow=True, ncol=3
    )

    if file_name:
        out_path = os.path.join('results', file_name)
        # savefig does not create missing directories.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path, bbox_inches='tight')
    else:
        plt.show()
|