"""
Created on Wed Mar 1 00:43:49 2023

@author: leona
"""

import torch
import numpy as np
import torch.nn as nn
from torch.distributions import MultivariateNormal
from torch.distributions import Categorical

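# This module implements PDPPO, a clipped-PPO agent that, in addition to the
# full state, tracks a pre-decision state (the machine setups) and a
# post-decision state (the inventory levels after applying the action), each
# scored by its own critic. RolloutBuffer stores the corresponding
# trajectories and ActorCritic holds the actor and critic networks.
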
print("============================================================================================")

device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    print("Device set to : cpu")
print("============================================================================================")


class RolloutBuffer:
    """On-policy storage for full, pre-decision and post-decision trajectories."""

    def __init__(self):
        self.actions = []
        self.actions_pre = []
        self.actions_post = []
        self.states = []
        self.pre_states = []
        self.post_states = []
        self.logprobs = []
        self.logprobs_pre = []
        self.logprobs_post = []
        self.rewards = []
        self.rewards_pre = []
        self.rewards_post = []
        self.state_values = []
        self.state_values_post = []
        self.is_terminals = []

    def clear(self):
        del self.actions[:]
        del self.actions_pre[:]
        del self.actions_post[:]
        del self.states[:]
        del self.pre_states[:]
        del self.post_states[:]
        del self.logprobs[:]
        del self.logprobs_pre[:]
        del self.logprobs_post[:]
        del self.rewards[:]
        del self.rewards_pre[:]
        del self.rewards_post[:]
        del self.state_values[:]
        del self.state_values_post[:]
        del self.is_terminals[:]


class ActorCritic(nn.Module):
    def __init__(self, state_dim, state_dim_pre, state_dim_post, action_dim, has_continuous_action_space, action_std_init):
        super(ActorCritic, self).__init__()

        self.action_dim = action_dim
        self.has_continuous_action_space = has_continuous_action_space

        # One actor head per state representation; each outputs the concatenated
        # logits of the MultiDiscrete action space (one sub-action per machine).
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, self.action_dim.nvec.sum())
        )

        self.actor_pre = nn.Sequential(
            nn.Linear(state_dim_pre, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, self.action_dim.nvec.sum())
        )

        self.actor_post = nn.Sequential(
            nn.Linear(state_dim_post, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, self.action_dim.nvec.sum())
        )

        # Separate critics for the pre- and post-decision states.
        self.critic_pre = nn.Sequential(
            nn.Linear(state_dim_pre, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

        self.critic_post = nn.Sequential(
            nn.Linear(state_dim_post, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

    def forward(self, state):
        raise NotImplementedError

    def set_action_std(self, new_action_std):
        if self.has_continuous_action_space:
            self.action_var = torch.full((self.action_dim,), new_action_std * new_action_std).to(device)
        else:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling ActorCritic::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")

    def act(self, state):
        # Forward the full state through the main actor and sample one
        # sub-action per machine from the resulting categorical distributions.
        logits = self.actor(state)
        action_probs = nn.functional.softmax(logits, dim=-1)
        dist = Categorical(action_probs.view(len(self.action_dim.nvec), -1))

        action = dist.sample()
        action_logprob = dist.log_prob(action)

        return action.detach(), action_logprob.detach()

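    # act_pre / act_post are required by PDPPO.select_action but are not
    # defined above; the sketches below assume they mirror act() with the
    # pre- and post-decision actor heads.
    def act_pre(self, pre_state):
        logits = self.actor_pre(pre_state)
        action_probs = nn.functional.softmax(logits, dim=-1)
        dist = Categorical(action_probs.view(len(self.action_dim.nvec), -1))
        action = dist.sample()
        return action.detach(), dist.log_prob(action).detach()

    def act_post(self, post_state):
        logits = self.actor_post(post_state)
        action_probs = nn.functional.softmax(logits, dim=-1)
        dist = Categorical(action_probs.view(len(self.action_dim.nvec), -1))
        action = dist.sample()
        return action.detach(), dist.log_prob(action).detach()
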
    def evaluate(self, state, pre_state, post_state, action):
        # Re-score the stored actions under the current policy and evaluate the
        # pre-/post-decision states with their respective critics.
        logits = self.actor(state)
        action_probs = nn.functional.softmax(logits, dim=-1)
        dist = Categorical(action_probs.view(state.shape[0], len(self.action_dim.nvec), -1))

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_values_pre = self.critic_pre(pre_state)
        state_values_post = self.critic_post(post_state)

        return action_logprobs, state_values_pre, state_values_post, dist_entropy


class PDPPO:
    def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space, env, action_std_init=0.6):

        self.has_continuous_action_space = has_continuous_action_space

        if has_continuous_action_space:
            self.action_std = action_std_init

        self.env = env

        self.reward_old_pre = -np.inf
        self.reward_old_post = -np.inf

        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.buffer = RolloutBuffer()

        # Pre-decision state: machine setups; post-decision state: inventory levels.
        state_dim_pre = self.env.n_machines
        state_dim_post = self.env.n_items

        self.policy = ActorCritic(state_dim, state_dim_pre, state_dim_post, action_dim, has_continuous_action_space, action_std_init).to(device)
        self.optimizer = torch.optim.Adam([
            {'params': self.policy.actor.parameters(), 'lr': lr_actor},
            {'params': self.policy.critic_pre.parameters(), 'lr': lr_critic * 10},
            {'params': self.policy.critic_post.parameters(), 'lr': lr_critic * 1}
        ])

        self.policy_old = ActorCritic(state_dim, state_dim_pre, state_dim_post, action_dim, has_continuous_action_space, action_std_init).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def get_post_state(self, action, machine_setup, inventory_level):
        # Apply the chosen action to obtain the post-decision state: update each
        # machine's setup, charge setup costs/losses where the setup changes, and
        # add the resulting production to the inventory levels.
        setup_loss = np.zeros(self.env.n_machines, dtype=int)
        setup_costs = np.zeros(self.env.n_machines)

        for m in range(self.env.n_machines):
            item = int(action[m])
            if item != 0:
                if machine_setup[m] != item:
                    setup_costs[m] = self.env.setup_costs[m][item - 1]
                    setup_loss[m] = self.env.setup_loss[m][item - 1]
                machine_setup[m] = item

                production = self.env.machine_production_matrix[m][item - 1] - setup_loss[m]
                inventory_level[item - 1] += production
            else:
                machine_setup[m] = 0

        return machine_setup, inventory_level, setup_costs

    def select_action(self, state):
        with torch.no_grad():
            state = torch.FloatTensor(state).to(device)
            action, action_logprob = self.policy_old.act(state)

            # Pre-decision state: the current machine setups.
            pre_state = state[self.env.n_items:self.env.n_items + self.env.n_machines].clone()

            # Post-decision state: inventory levels after applying the action.
            # Cloned slices are passed so the stored state is not modified in place.
            machine_setup, inventory_level, setup_cost = self.get_post_state(
                action,
                state[self.env.n_items:self.env.n_items + self.env.n_machines].clone(),
                state[0:self.env.n_items].clone())

            post_state = inventory_level.clone()

            action_pre, action_logprob_pre = self.policy_old.act_pre(pre_state)
            action_post, action_logprob_post = self.policy_old.act_post(post_state)

        self.buffer.states.append(state)
        self.buffer.pre_states.append(pre_state)
        self.buffer.post_states.append(post_state)
        self.buffer.actions.append(action)
        self.buffer.actions_pre.append(action_pre)
        self.buffer.actions_post.append(action_post)
        self.buffer.logprobs.append(action_logprob)
        self.buffer.logprobs_pre.append(action_logprob_pre)
        self.buffer.logprobs_post.append(action_logprob_post)

        with torch.no_grad():
            state_val = self.policy_old.critic_pre(pre_state).detach()
            state_val_post = self.policy_old.critic_post(post_state).detach()

        self.buffer.state_values.append(state_val)
        self.buffer.state_values_post.append(state_val_post)

        if self.has_continuous_action_space:
            return action.detach().cpu().numpy().flatten()
        else:
            return action.detach().cpu().numpy()

    def update(self):
        # Monte-Carlo returns for the full reward stream.
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)

        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
        rewards = rewards / (-rewards).max()
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)

        # Monte-Carlo returns for the pre-decision reward stream.
        rewards_pre = []
        discounted_reward = 0
        for reward_pre, is_terminal in zip(reversed(self.buffer.rewards_pre), reversed(self.buffer.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward_pre + (self.gamma * discounted_reward)
            rewards_pre.insert(0, discounted_reward)

        rewards_pre = torch.tensor(rewards_pre, dtype=torch.float32).to(device)
        rewards_pre = (rewards_pre - rewards_pre.mean()) / (rewards_pre.std() + 1e-7)

        # Monte-Carlo returns for the post-decision reward stream.
        rewards_post = []
        discounted_reward = 0
        for reward_post, is_terminal in zip(reversed(self.buffer.rewards_post), reversed(self.buffer.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward_post + (self.gamma * discounted_reward)
            rewards_post.insert(0, discounted_reward)

        rewards_post = torch.tensor(rewards_post, dtype=torch.float32).to(device)
        rewards_post = (rewards_post - rewards_post.mean()) / (rewards_post.std() + 1e-7)

        # Detach the collected rollout as fixed "old" tensors.
        old_states = torch.squeeze(torch.stack(self.buffer.states, dim=0)).detach().to(device)
        old_pre_states = torch.squeeze(torch.stack(self.buffer.pre_states, dim=0)).detach().to(device)
        old_post_states = torch.squeeze(torch.stack(self.buffer.post_states, dim=0)).detach().to(device)
        old_actions = torch.squeeze(torch.stack(self.buffer.actions, dim=0)).detach().to(device)
        old_actions_pre = torch.squeeze(torch.stack(self.buffer.actions_pre, dim=0)).detach().to(device)
        old_actions_post = torch.squeeze(torch.stack(self.buffer.actions_post, dim=0)).detach().to(device)
        old_logprobs = torch.squeeze(torch.stack(self.buffer.logprobs, dim=0)).detach().to(device)
        old_logprobs_pre = torch.squeeze(torch.stack(self.buffer.logprobs_pre, dim=0)).detach().to(device)
        old_logprobs_post = torch.squeeze(torch.stack(self.buffer.logprobs_post, dim=0)).detach().to(device)
        old_state_values = torch.squeeze(torch.stack(self.buffer.state_values, dim=0)).detach().to(device)
        old_state_values_post = torch.squeeze(torch.stack(self.buffer.state_values_post, dim=0)).detach().to(device)

        # Advantages relative to the stored value estimates.
        advantages_post = rewards_post.detach() - old_state_values_post.detach()
        advantages_pre = rewards_pre.detach() - old_state_values.detach()
        advantages = rewards.detach() - old_state_values_post.detach() - old_state_values.detach()

        sum_loss = 0

        # PPO epochs over the collected rollout.
        for _ in range(self.K_epochs):
            # evaluate() returns the log-probs of the stored actions under the
            # current policy together with both critics' value estimates.
            logprobs, state_values_pre, state_values_post, dist_entropy = self.policy.evaluate(
                old_states, old_pre_states, old_post_states, old_actions)

            # Probability ratio between current and old policy.
            ratios = torch.exp(logprobs - old_logprobs)

            # Clipped surrogate objective.
            surr1 = ratios * advantages.unsqueeze(1)
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages.unsqueeze(1)
            surr = -torch.min(surr1, surr2)

            # Total loss: clipped surrogate, value losses regressing each critic
            # onto its discounted returns, and an entropy bonus.
            loss = (surr
                    + 0.5 * self.MseLoss(state_values_pre.squeeze(-1), rewards_pre)
                    + 0.5 * self.MseLoss(state_values_post.squeeze(-1), rewards_post)
                    - 0.01 * dist_entropy)

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        print('Last Loss {}'.format(loss.sum().item()))

        self.policy_old.load_state_dict(self.policy.state_dict())

        self.buffer.clear()

    def save(self, checkpoint_path):
        torch.save(self.policy_old.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        self.policy_old.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
        self.policy.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
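

# ----------------------------------------------------------------------------
# Minimal usage sketch, assuming a Gym-style lot-sizing environment that
# exposes n_machines, n_items, machine_production_matrix, setup_costs and
# setup_loss, whose observation starts with the inventory levels followed by
# the machine setups, and whose action space is a MultiDiscrete with one
# entry per machine. The environment name and the decomposed rewards read
# from `info` are assumptions for illustration only.
#
# env = LotSizingEnv()                                   # hypothetical env
# agent = PDPPO(state_dim=env.observation_space.shape[0],
#               action_dim=env.action_space,
#               lr_actor=3e-4, lr_critic=1e-3,
#               gamma=0.99, K_epochs=10, eps_clip=0.2,
#               has_continuous_action_space=False, env=env)
#
# state = env.reset()
# for t in range(10000):
#     action = agent.select_action(state)
#     state, reward, done, info = env.step(action)
#     agent.buffer.rewards.append(reward)
#     agent.buffer.rewards_pre.append(info.get("reward_pre", reward))
#     agent.buffer.rewards_post.append(info.get("reward_post", reward))
#     agent.buffer.is_terminals.append(done)
#     if done:
#         agent.update()
#         state = env.reset()
# ----------------------------------------------------------------------------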