| |
| """ |
| Created on Wed Mar 1 00:43:49 2023 |
| |
| @author: leona |
| """ |
|
|
| import numpy as np |
| import torch |
| from sklearn.cluster import KMeans |
| import torch.nn as nn |
| import torch.nn.init as init |
| import torch.nn.functional as F |
| from torch.distributions import MultivariateNormal |
| from torch.distributions import Categorical |
|
|
| |
| |
| |
# Every tensor in this module is kept on the CPU: PDPPO.select_action turns
# sampled actions back into NumPy arrays with ``.numpy()``, which only works
# for CPU tensors.
device = torch.device("cpu")
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
class NegReLU(nn.Module):
    """Negated ReLU activation, ``f(x) = -max(x, 0)``.

    Positive inputs are mapped to their negation and non-positive inputs to
    zero, so every output is <= 0.
    """

    def forward(self, x):
        rectified = torch.relu(x)
        return rectified.neg()
|
|
| import torch |
| import torch.nn as nn |
|
|
class NoisyLinear(nn.Module):
    """Linear layer with learnable Gaussian parameter noise (NoisyNet-style).

    The effective parameters are ``mu + sigma * eps``, where ``eps`` is drawn
    fresh from a standard normal on every forward pass, clamped to [-2, 2]
    and multiplied by ``noise_scale``. Noise is sampled independently for
    every weight and bias entry (not factorised).

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        std_init: initial value of every entry of ``weight_sigma`` and
            ``bias_sigma``.
    """

    def __init__(self, in_features, out_features, std_init=0.4):
        super(NoisyLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.std_init = std_init

        # Allocate storage only; reset_parameters() writes every entry, so
        # there is no point doing arithmetic on uninitialised memory here.
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features))
        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        self.reset_parameters()

        # Noise samples are buffers rather than parameters: they follow the
        # module across .to() calls but are never optimised.
        self.register_buffer('weight_epsilon', torch.empty(out_features, in_features))
        self.register_buffer('bias_epsilon', torch.empty(out_features))

    def reset_parameters(self):
        """Initialise means like ``nn.Linear`` and set all sigmas to ``std_init``."""
        nn.init.kaiming_uniform_(self.weight_mu, a=np.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_mu)
        # Same guard as nn.Linear: avoid dividing by zero for in_features == 0.
        bound = 1 / np.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.bias_mu, -bound, bound)

        nn.init.constant_(self.weight_sigma, self.std_init)
        nn.init.constant_(self.bias_sigma, self.std_init)

    def forward(self, input, noise_scale=1.0):
        """Apply the noisy affine map to ``input``.

        Args:
            input: tensor whose last dimension is ``in_features``.
            noise_scale: multiplier on the clamped noise; 0 yields the mean
                (noise-free) layer.

        Returns:
            Tensor whose last dimension is ``out_features``.
        """
        # Resample noise in place on every call.
        self.weight_epsilon.normal_()
        self.bias_epsilon.normal_()

        # Clamp to +/-2 standard deviations to bound the perturbation.
        weight_noise = self.weight_epsilon.clamp(-2, 2) * noise_scale
        bias_noise = self.bias_epsilon.clamp(-2, 2) * noise_scale

        weight = self.weight_mu + self.weight_sigma * weight_noise
        bias = self.bias_mu + self.bias_sigma * bias_noise
        return F.linear(input, weight, bias)
|
|
|
|
|
|
| |
class RolloutBuffer:
    """Per-step storage for on-policy rollouts.

    Every attribute is a plain list holding one entry per collected step.
    """

    def __init__(self):
        self.actions = []
        self.states = []
        self.post_states = []
        self.logprobs = []
        self.rewards = []
        self.post_rewards = []
        self.state_values = []
        self.state_values_post = []
        self.is_terminals = []

    def clear(self, lag):
        """Drop the first ``lag`` entries of every per-step list.

        Each list is rebound to the new list ``old[lag:]``: a ``lag`` of 0
        keeps everything, and a ``lag`` at least as long as a list empties it.
        """
        per_step_fields = (
            "actions",
            "states",
            "post_states",
            "logprobs",
            "rewards",
            "post_rewards",
            "state_values",
            "state_values_post",
            "is_terminals",
        )
        for field in per_step_fields:
            setattr(self, field, getattr(self, field)[lag:])
|
|
class ActorCritic(nn.Module):
    """Actor network plus two value heads (pre- and post-decision critics).

    Continuous action spaces use a Tanh MLP whose output is the mean of a
    Gaussian with diagonal covariance ``diag(action_var)``. Discrete action
    spaces must expose ``nvec`` (presumably a gym ``MultiDiscrete`` space):
    a ReLU trunk ``fc1 -> fc2`` feeds a linear head ``actor`` that emits
    ``sum(nvec)`` logits, giving one independent categorical per sub-action.

    ``critic`` scores states and ``critic_post`` scores post-decision states;
    both map ``state_dim`` features to a single value.
    """

    def __init__(self, state_dim, action_dim, has_continuous_action_space, action_std_init, noise_decay_rate=0.9976):
        super(ActorCritic, self).__init__()
        self.noise_scale = 1.0
        self.noise_decay_rate = noise_decay_rate
        self.has_continuous_action_space = has_continuous_action_space

        if has_continuous_action_space:
            self.action_dim = action_dim
            self.action_var = torch.full((action_dim,), action_std_init * action_std_init).to(device)
            self.actor = nn.Sequential(
                nn.Linear(state_dim, 128),
                nn.Tanh(),
                nn.Linear(128, 128),
                nn.Tanh(),
                nn.Linear(128, action_dim),
                nn.Tanh(),
            )
        else:
            self.action_dim = action_dim
            self.fc1 = nn.Linear(state_dim, 128)
            self.fc2 = nn.Linear(128, 128)
            # One logit per choice of every sub-action, laid out contiguously.
            self.actor = nn.Linear(128, int(self.action_dim.nvec.sum()))

        self.critic = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 128),
            nn.Tanh(),
            nn.Linear(128, 1),
        )

        self.critic_post = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 128),
            nn.Tanh(),
            nn.Linear(128, 1),
        )

    def decay_noise(self):
        """Multiply ``noise_scale`` by ``noise_decay_rate``."""
        self.noise_scale *= self.noise_decay_rate

    def _initialize_actor(self, m):
        """Kaiming-uniform (ReLU gain) weights and zero biases for Linear layers."""
        if isinstance(m, nn.Linear):
            init.kaiming_uniform_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                init.zeros_(m.bias)

    def _initialize_critic(self, m):
        """Xavier-uniform weights and zero biases for Linear layers."""
        if isinstance(m, nn.Linear):
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                init.zeros_(m.bias)

    def forward(self, state):
        raise NotImplementedError

    def set_action_std(self, new_action_std):
        """Set the Gaussian policy's per-dimension std (continuous spaces only)."""
        if self.has_continuous_action_space:
            self.action_var = torch.full((self.action_dim,), new_action_std * new_action_std).to(device)
        else:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling ActorCritic::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")

    def _discrete_distribution(self, state, batched):
        """Build one Categorical per sub-action from the discrete actor.

        The logits of sub-action ``i`` form a contiguous slice of length
        ``nvec[i]``. If sub-actions have different sizes, shorter slices are
        padded with ``-inf`` up to ``max(nvec)``, so padded choices get zero
        probability and are never sampled.

        Args:
            state: input features; last dimension is ``state_dim``.
            batched: False scores exactly one state (batch shape
                ``(len(nvec),)``); True flattens every leading dimension into
                one batch dimension (batch shape ``(batch, len(nvec))``).
        """
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        logits = self.actor(x)

        nvec = [int(n) for n in self.action_dim.nvec]
        n_sub, width = len(nvec), max(nvec)
        total = logits.shape[-1]
        logits = logits.reshape(-1, total) if batched else logits.reshape(total)

        if all(n == width for n in nvec):
            logits_shaped = logits.reshape(*logits.shape[:-1], n_sub, width)
        else:
            chunks = torch.split(logits, nvec, dim=-1)
            logits_shaped = torch.stack(
                [F.pad(chunk, (0, width - n), value=float('-inf')) for chunk, n in zip(chunks, nvec)],
                dim=-2,
            )
        action_probs = F.softmax(logits_shaped, dim=-1)
        return Categorical(action_probs)

    def act(self, state):
        """Sample an action for a single state.

        Returns:
            (action, log_prob), both detached. For discrete spaces both have
            one entry per sub-action.
        """
        if self.has_continuous_action_space:
            action_mean = self.actor(state)
            cov_mat = torch.diag(self.action_var).unsqueeze(dim=0)
            dist = MultivariateNormal(action_mean, cov_mat)
        else:
            dist = self._discrete_distribution(state, batched=False)

        action = dist.sample()
        action_logprob = dist.log_prob(action)
        return action.detach(), action_logprob.detach()

    def evaluate(self, state, post_state, action):
        """Score stored actions under the current policy and both critics.

        Returns:
            (action_logprobs, state_values, state_values_post, dist_entropy).
        """
        if self.has_continuous_action_space:
            action_mean = self.actor(state)
            action_var = self.action_var.expand_as(action_mean)
            cov_mat = torch.diag_embed(action_var).to(device)
            dist = MultivariateNormal(action_mean, cov_mat)
            # For a 1-D action space, give the action a trailing event
            # dimension so it matches the distribution's event shape.
            if self.action_dim == 1:
                action = action.reshape(-1, self.action_dim)
        else:
            dist = self._discrete_distribution(state, batched=True)

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_values = self.critic(state)
        state_values_post = self.critic_post(post_state)
        return action_logprobs, state_values, state_values_post, dist_entropy
|
|
|
|
class PDPPO:
    """Post-decision PPO agent.

    ``policy`` is the network being optimised; ``policy_old`` is a copy that
    is refreshed at the end of every ``update`` and used by ``select_action``.
    Besides the usual state critic, a second critic (``critic_post``) is
    trained on post-decision states produced by ``get_post_state``. The PPO
    advantage is the element-wise maximum of the two advantage estimates.

    NOTE(review): ``update`` stacks ``buffer.post_states``,
    ``buffer.post_rewards`` and ``buffer.state_values_post``, which only the
    discrete branch of ``select_action`` fills.
    """

    def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, env,
                 has_continuous_action_space, tau, action_std_init=0.6):
        self.has_continuous_action_space = has_continuous_action_space
        if has_continuous_action_space:
            self.action_std = action_std_init

        self.tau = tau
        self.env = env
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.buffer = RolloutBuffer()

        self.policy = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        self.policy.actor.apply(self.policy._initialize_actor)
        self.policy.critic.apply(self.policy._initialize_critic)
        self.policy.critic_post.apply(self.policy._initialize_critic)

        # For discrete actions the policy is fc1 -> fc2 -> actor; the trunk
        # must be optimised with the head, otherwise fc1/fc2 never learn.
        self.optimizer_actor = torch.optim.Adam(self._actor_parameters(self.policy), lr=lr_actor)
        self.optimizer_critic = torch.optim.Adam(self.policy.critic.parameters(), lr=lr_critic)
        self.optimizer_critic_post = torch.optim.Adam(self.policy.critic_post.parameters(), lr=lr_critic)

        self.policy_old = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    @staticmethod
    def _actor_parameters(policy):
        """Return every parameter that feeds the policy's action distribution."""
        if policy.has_continuous_action_space:
            modules = [policy.actor]
        else:
            modules = [policy.fc1, policy.fc2, policy.actor]
        return [p for module in modules for p in module.parameters()]

    def set_action_std(self, new_action_std):
        """Set the action std on both policies (continuous spaces only)."""
        if self.has_continuous_action_space:
            self.action_std = new_action_std
            self.policy.set_action_std(new_action_std)
            self.policy_old.set_action_std(new_action_std)
        else:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling PDPPO::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")

    def decay_action_std(self, action_std_decay_rate, min_action_std):
        """Linearly decrease the action std, floored at ``min_action_std``."""
        print("--------------------------------------------------------------------------------------------")
        if self.has_continuous_action_space:
            self.action_std = self.action_std - action_std_decay_rate
            self.action_std = round(self.action_std, 4)
            if (self.action_std <= min_action_std):
                self.action_std = min_action_std
                print("setting actor output action_std to min_action_std : ", self.action_std)
            else:
                print("setting actor output action_std to : ", self.action_std)
            self.set_action_std(self.action_std)
        else:
            print("WARNING : Calling PDPPO::decay_action_std() on discrete action space policy")
        print("--------------------------------------------------------------------------------------------")

    def get_post_state(self, action, machine_setup, inventory_level):
        """Apply an action's setup and production effects using ``self.env`` data.

        For each machine ``m`` with a non-zero action ``item = action[m]``
        (1-based item index):
          * if the machine's setup differs from ``item``, it pays
            ``env.setup_costs[m][item-1]``, loses
            ``env.setup_loss[m][item-1]`` units of production, and its setup
            becomes ``item``;
          * it adds ``env.machine_production_matrix[m][item-1]`` minus that
            loss (0 when no setup change) to ``inventory_level[item-1]``.
        A zero action resets the machine's setup to 0.

        ``machine_setup`` and ``inventory_level`` are modified in place and
        also returned, so callers should pass copies.

        Returns:
            (machine_setup, inventory_level, setup_costs); ``setup_costs`` is
            a float array with one entry per machine.
        """
        n_machines = self.env.n_machines
        setup_loss = np.zeros(n_machines, dtype=int)
        setup_costs = np.zeros(n_machines)

        for m in range(n_machines):
            item = action[m]
            if item == 0:
                machine_setup[m] = 0
                continue
            if machine_setup[m] != item:
                setup_costs[m] = self.env.setup_costs[m][item - 1]
                setup_loss[m] = self.env.setup_loss[m][item - 1]
                machine_setup[m] = item
            production = self.env.machine_production_matrix[m][item - 1] - setup_loss[m]
            inventory_level[item - 1] += production

        return machine_setup, inventory_level, setup_costs

    def select_action(self, state, tau):
        """Sample an action from ``policy_old`` and record the step in the buffer.

        ``tau`` is currently unused.

        Returns:
            Continuous: the flattened action array.
            Discrete: ``(action, post_rewards)`` as NumPy arrays, where
            ``post_rewards`` is the negated total setup cost of the action.
        """
        if self.has_continuous_action_space:
            with torch.no_grad():
                state = torch.FloatTensor(state).to(device)
                # ActorCritic.act takes only the state and returns
                # (action, logprob); the state value comes from the critic.
                action, action_logprob = self.policy_old.act(state)
                state_val = self.policy_old.critic(state)

            self.buffer.states.append(state)
            self.buffer.actions.append(action)
            self.buffer.logprobs.append(action_logprob)
            self.buffer.state_values.append(state_val)

            return action.detach().numpy().flatten()

        with torch.no_grad():
            state = torch.FloatTensor(state).to(device)
            action, action_logprob = self.policy_old.act(state)

        # Post-decision state: the state with the first n_items features
        # replaced by the inventory after this action's production.
        _, inventory_level, setup_cost = self.get_post_state(
            action=action.clone(),
            machine_setup=self.env.machine_setup.copy(),
            inventory_level=state[0:self.env.n_items].clone(),
        )
        post_state = state.clone()
        post_state[0:self.env.n_items] = inventory_level
        post_state = post_state.to(device)

        self.buffer.states.append(state)
        self.buffer.post_states.append(post_state)
        self.buffer.actions.append(action)
        self.buffer.logprobs.append(action_logprob)
        post_rewards = torch.FloatTensor([-sum(setup_cost)])
        self.buffer.post_rewards.append(post_rewards)

        with torch.no_grad():
            state_val = self.policy_old.critic(state)
            state_val_post = self.policy_old.critic_post(post_state)

        self.buffer.state_values.append(state_val)
        self.buffer.state_values_post.append(state_val_post)

        return action.numpy(), post_rewards.numpy()

    @staticmethod
    def _discounted_returns(rewards, is_terminals, gamma):
        """Compute discounted returns, restarting the sum at terminal steps.

        Walking backwards, ``G_t = r_t + gamma * G_{t+1}``, where ``G_{t+1}``
        is taken as 0 when step ``t`` is terminal. Each reward may be a
        Python/NumPy scalar or a one-element tensor. The result is always a
        1-D float32 tensor with one entry per step, so it lines up with the
        squeezed value estimates it is compared against.
        """
        returns = []
        running = 0
        for reward, is_terminal in zip(reversed(rewards), reversed(is_terminals)):
            if is_terminal:
                running = 0
            running = reward + gamma * running
            returns.append(float(running))
        returns.reverse()
        return torch.tensor(returns, dtype=torch.float32)

    def update(self):
        """Run ``K_epochs`` of PPO updates on the buffer, then sync ``policy_old``."""
        rewards = self._discounted_returns(self.buffer.rewards, self.buffer.is_terminals, self.gamma).to(device)
        post_rewards = self._discounted_returns(self.buffer.post_rewards, self.buffer.is_terminals, self.gamma).to(device)

        old_states = torch.squeeze(torch.stack(self.buffer.states, dim=0)).detach().to(device)
        old_post_states = torch.squeeze(torch.stack(self.buffer.post_states, dim=0)).detach().to(device)
        old_actions = torch.squeeze(torch.stack(self.buffer.actions, dim=0)).detach().to(device)
        old_logprobs = torch.squeeze(torch.stack(self.buffer.logprobs, dim=0)).detach().to(device)
        old_state_values = torch.squeeze(torch.stack(self.buffer.state_values, dim=0)).detach().to(device)
        old_state_values_post = torch.squeeze(torch.stack(self.buffer.state_values_post, dim=0)).detach().to(device)

        advantages_current = rewards - old_state_values
        advantages_post = post_rewards - old_state_values_post
        # Use whichever advantage estimate is larger for each step.
        advantages = torch.max(advantages_current, advantages_post)

        for _ in range(self.K_epochs):
            logprobs, state_values, post_state_values, dist_entropy = self.policy.evaluate(
                old_states, old_post_states, old_actions)

            ratios = torch.exp(logprobs - old_logprobs.detach())

            # One advantage per step, broadcast over any trailing sub-action
            # dimension of ratios.
            adv = advantages.reshape(-1, *([1] * (ratios.dim() - 1)))
            surr1 = ratios * adv
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * adv

            critic_loss = self.MseLoss(state_values.squeeze(), rewards)
            critic_loss_post = self.MseLoss(post_state_values.squeeze(), post_rewards)

            # The detached critic terms shift the loss value only; they
            # contribute no gradient to the actor.
            actor_loss = (-torch.min(surr1, surr2) - 0.001 * dist_entropy).mean() \
                + 0.7 * (critic_loss.detach() + critic_loss_post.detach())

            self.optimizer_actor.zero_grad()
            actor_loss.backward()
            self.optimizer_actor.step()

            self.optimizer_critic.zero_grad()
            critic_loss.backward()
            self.optimizer_critic.step()

            self.optimizer_critic_post.zero_grad()
            critic_loss_post.backward()
            self.optimizer_critic_post.step()

        self.policy_old.load_state_dict(self.policy.state_dict())

    def save(self, checkpoint_path):
        """Write ``policy_old``'s state dict to ``checkpoint_path``."""
        torch.save(self.policy_old.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Load one checkpoint into both ``policy_old`` and ``policy``.

        NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        self.policy_old.load_state_dict(state_dict)
        self.policy.load_state_dict(state_dict)