File size: 1,798 Bytes
a241478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# -*- coding: utf-8 -*-
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches


def plot_policy(agent, file_name=None):
    """Tabulate an agent's deterministic policy and plot it as heatmaps.

    The policy is queried on every observation ``[i, j, k]`` with
    ``0 <= i <= max_inventory_level[0]``, ``0 <= j <= max_inventory_level[1]``
    and ``0 <= k <= n_items``. The resulting table is stored on
    ``agent.policy`` with shape
    ``(max_inventory_level[0] + 1, max_inventory_level[1] + 1, n_items + 1)``.
    It is drawn as one heatmap per setup index ``k`` (``I1`` on the vertical
    axis, ``I2`` on the horizontal axis).

    Only single-machine environments are supported. When
    ``agent.env.n_machines > 1`` a message is printed and nothing is computed
    or plotted.

    Args:
        agent: Object exposing ``env`` (with ``n_machines``,
            ``max_inventory_level`` and ``n_items``), ``POSSIBLE_STATES``
            (number of subplots and of distinct action values shown) and
            ``get_action(obs, deterministic=True)``.
        file_name: If truthy, the figure is saved to
            ``results/<file_name>`` (creating the directory if needed) and
            then closed. Otherwise it is displayed with ``plt.show()``.
    """
    env = agent.env
    if env.n_machines > 1:
        print('Impossible to print for n_machine > 1')
        return

    shape = (
        env.max_inventory_level[0] + 1,
        env.max_inventory_level[1] + 1,
        env.n_items + 1,
    )
    policy_map = np.zeros(shape)
    # TODO: finish generalizing this function (currently two inventory levels).
    # np.ndindex walks in C order, i.e. the same order as nested i/j/k loops.
    for i, j, k in np.ndindex(*shape):
        # get_action is called with a batch containing a single observation.
        obs = np.expand_dims(np.array([i, j, k]), axis=0)
        action = agent.get_action(obs, deterministic=True)
        # Unwrap a size-1 result (e.g. shape (1,)) into a scalar; .item()
        # raises if get_action unexpectedly returns several values.
        policy_map[i, j, k] = np.asarray(action).item()
    agent.policy = policy_map

    # NOTE(review): POSSIBLE_STATES is used both as the number of setups
    # plotted and as the number of action values; presumably it equals
    # n_items + 1 — confirm against the agent implementation.
    n_states = agent.POSSIBLE_STATES

    # One discrete color per action value. With clim (0, n_states - 1) below,
    # an action value v is rendered with color cmap(v).
    cmap = plt.get_cmap('viridis', n_states)
    # squeeze=False keeps axs 2-D, so the loop also works for one subplot.
    fig, axs = plt.subplots(1, n_states, squeeze=False)
    fig.suptitle('Found Policy')
    for setup, ax in enumerate(axs[0]):
        ax.set_title(f'Setup {setup}')
        mesh = ax.pcolormesh(
            agent.policy[:, :, setup], cmap=cmap, edgecolors='k', linewidth=2
        )
        mesh.set_clim(0, n_states - 1)
        ax.set_xlabel('I2')
        if setup == 0:
            ax.set_ylabel('I1')

    # Discrete "color bar": one legend patch per action value, centered
    # below all subplots regardless of how many there are.
    fig.subplots_adjust(bottom=0.2)
    fig.legend(
        [mpatches.Patch(color=cmap(value)) for value in range(n_states)],
        [str(value) for value in range(n_states)],
        loc='lower center', bbox_to_anchor=(0.5, 0.0),
        fancybox=True, shadow=True, ncol=n_states,
    )

    if file_name:
        path = os.path.join('results', file_name)
        directory = os.path.dirname(path)
        if directory:
            # savefig does not create missing directories.
            os.makedirs(directory, exist_ok=True)
        fig.savefig(path, bbox_inches='tight')
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()