pdppo / code /Lot-sizing /test_functions /plotFunctions.py
leokana's picture
first commit
a241478
# -*- coding: utf-8 -*-
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
def _evaluate_policy_grid(agent):
    """Return the agent's deterministic action for every grid state.

    The grid spans ``(i, j, k)`` with ``i`` in ``0..max_inventory_level[0]``,
    ``j`` in ``0..max_inventory_level[1]`` and ``k`` in ``0..n_items``. Each
    state is passed to ``agent.get_action`` as an array of shape ``(1, 3)``.
    """
    max_inventory = agent.env.max_inventory_level
    shape = (
        max_inventory[0] + 1,
        max_inventory[1] + 1,
        agent.env.n_items + 1,
    )
    policy_map = np.zeros(shape)
    # TODO: finish generalising this function (only the first two inventory
    # levels are used for now).
    # np.ndindex walks the grid in the same order as three nested range loops.
    for i, j, k in np.ndindex(*shape):
        obs = np.expand_dims(np.array([i, j, k]), axis=0)
        policy_map[i, j, k] = agent.get_action(obs, deterministic=True)
    return policy_map


def plot_policy(agent, file_name=None):
    """Plot the agent's deterministic policy as one heat map per setup.

    The policy is evaluated on the full state grid, stored on
    ``agent.policy``, and drawn with one subplot for each of the
    ``agent.POSSIBLE_STATES`` slices along the last axis.

    Args:
        agent: Object exposing ``env.n_machines``, ``env.max_inventory_level``,
            ``env.n_items``, ``POSSIBLE_STATES`` and
            ``get_action(obs, deterministic=True)``.
        file_name: If truthy, the figure is saved to ``results/<file_name>``;
            otherwise it is shown with ``plt.show()``.

    Returns:
        None. When ``agent.env.n_machines > 1`` a message is printed and
        nothing is computed or plotted.
    """
    if agent.env.n_machines > 1:
        print('Impossible to print for n_machine > 1')
        return

    # NOTE(review): the colour count is hard-coded to 3 while the colour
    # limits below use POSSIBLE_STATES - 1; confirm they are meant to agree.
    n_colors = 3
    # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap accepts
    # the same (name, lut) arguments.
    cmap = plt.get_cmap('viridis', n_colors)

    agent.policy = _evaluate_policy_grid(agent)

    n_setups = agent.POSSIBLE_STATES
    # squeeze=False always yields a 2-D array of Axes, so a single subplot
    # remains iterable.
    fig, axs = plt.subplots(1, n_setups, squeeze=False)
    axs = axs[0]
    fig.suptitle('Found Policy')
    for i, ax in enumerate(axs):
        ax.set_title(f'Setup {i}')
        im = ax.pcolormesh(
            agent.policy[:, :, i], cmap=cmap, edgecolors='k', linewidth=2
        )
        im.set_clim(0, n_setups - 1)
        ax.set_xlabel('I2')
        if i == 0:
            ax.set_ylabel('I1')

    # COLOR BAR: one legend patch per colour of the discrete colormap.
    bound = list(range(n_colors))
    fig.subplots_adjust(bottom=0.2)
    # The legend is attached to the last subplot and shifted left/below it
    # through bbox_to_anchor.
    axs[-1].legend(
        [mpatches.Patch(color=cmap(b)) for b in bound],
        [str(b) for b in bound],
        loc='upper center', bbox_to_anchor=(-0.8, -0.13),
        fancybox=True, shadow=True, ncol=3
    )

    if file_name:
        out_path = os.path.join('results', file_name)
        # Create the output directory so savefig does not fail on a fresh
        # checkout.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path, bbox_inches='tight')
    else:
        plt.show()