"""
Created on Wed Mar  1 00:43:49 2023

@author: leona
"""
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.init as init |
| from torch.distributions import MultivariateNormal |
| from torch.distributions import Categorical |
|
|
| |
| |
| |
# Single shared torch device for every tensor in this module; CPU-only here.
device = torch.device('cpu')
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
class NegReLU(nn.Module):
    """Negated ReLU activation: maps x to -max(x, 0), i.e. outputs are <= 0."""

    def forward(self, x):
        # clamp(x, min=0) is ReLU; negate to flip the output range to (-inf, 0].
        return torch.clamp(x, min=0.0).neg()
|
|
| |
class RolloutBuffer:
    """Per-transition storage for PPO rollouts.

    Holds parallel lists (one entry per environment step) of actions, states,
    log-probabilities, rewards, state values, and terminal flags.
    """

    # All parallel list fields, handled uniformly below.
    _FIELDS = ("actions", "states", "logprobs", "rewards",
               "state_values", "is_terminals")

    def __init__(self):
        for name in self._FIELDS:
            setattr(self, name, [])

    def clear(self, lag):
        """Drop the first `lag` entries of every list, keeping only the tail."""
        for name in self._FIELDS:
            setattr(self, name, getattr(self, name)[lag:])
|
|
|
|
class ActorCritic(nn.Module):
    """Combined actor-critic network for PPO.

    Two action-space modes:
      * continuous -- `action_dim` is an int; the actor MLP's Tanh output is
        the mean of a diagonal Gaussian whose (state-independent) variance is
        `action_var`, derived from `action_std_init`.
      * discrete -- `action_dim` is expected to be a MultiDiscrete-like space
        object exposing `.nvec` (the code calls `.nvec.sum()` and
        `.nvec.max()`); a single linear head emits concatenated per-branch
        logits.
    """

    def __init__(self, state_dim, action_dim, has_continuous_action_space, action_std_init):
        super(ActorCritic, self).__init__()

        self.has_continuous_action_space = has_continuous_action_space

        if has_continuous_action_space:
            self.action_dim = action_dim
            # Shared per-dimension action variance (std squared), fixed until
            # set_action_std() is called.
            self.action_var = torch.full((action_dim,), action_std_init * action_std_init).to(device)

        if has_continuous_action_space :
            # Final Tanh bounds the action mean to [-1, 1].
            self.actor = nn.Sequential(
                nn.Linear(state_dim, 64),
                nn.Tanh(),
                nn.Linear(64, 64),
                nn.Tanh(),
                nn.Linear(64, action_dim),
                nn.Tanh()
            )
        else:
            # Discrete: `action_dim` is a space object with `.nvec` (one entry
            # per action branch). The actor head outputs nvec.sum() logits.
            self.action_dim = action_dim
            self.fc1 = nn.Linear(state_dim, 128)
            self.fc2 = nn.Linear(128, 128)
            self.actor = nn.Linear(128, self.action_dim.nvec.sum())

        # State-value head, shared by both action-space modes.
        self.critic = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

    def _initialize_actor(self, m):
        """`nn.Module.apply` hook: Kaiming-uniform weights, zero biases, for actor Linear layers."""
        if isinstance(m, nn.Linear):
            # NOTE(review): kaiming_uniform_ with nonlinearity='tanh' is an
            # unusual pairing (Kaiming gains target ReLU-family activations)
            # -- confirm this is intentional.
            init.kaiming_uniform_(m.weight, nonlinearity='tanh')
            if m.bias is not None:
                init.zeros_(m.bias)

    def _initialize_critic(self, m):
        """`nn.Module.apply` hook: Xavier-uniform weights, zero biases, for critic Linear layers."""
        if isinstance(m, nn.Linear):
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                init.zeros_(m.bias)

    def forward(self, state):
        # Direct calls are unsupported; use act() / evaluate() instead.
        raise NotImplementedError

    def set_action_std(self, new_action_std):
        """Replace the shared action variance with new_action_std**2 (continuous mode only)."""
        if self.has_continuous_action_space:
            self.action_var = torch.full((self.action_dim,), new_action_std * new_action_std).to(device)
        else:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling ActorCritic::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")

    def act(self, state):
        """Sample an action for one state.

        Returns a detached (action, log_prob, state_value) triple for rollout
        storage -- no gradients flow through these.
        """
        if self.has_continuous_action_space:
            action_mean = self.actor(state)
            # Diagonal covariance built from the shared action_var.
            cov_mat = torch.diag(self.action_var).unsqueeze(dim=0)
            dist = MultivariateNormal(action_mean, cov_mat)
        else:
            x = nn.functional.relu(self.fc2(nn.functional.relu(self.fc1(state))))
            logits = self.actor(x)
            # Reshape concatenated logits to (n_branches, max_branch_size).
            # NOTE(review): this view assumes a single non-batched state, and
            # only works when nvec.sum() == len(nvec) * nvec.max(), i.e. all
            # branches are the same size; padded positions would not be masked
            # before softmax otherwise -- confirm all nvec entries are equal.
            logits_shaped = logits.view(len(self.action_dim.nvec), self.action_dim.nvec.max())
            action_probs = nn.functional.softmax(logits_shaped, dim=-1)
            dist = Categorical(action_probs)

        action = dist.sample()
        action_logprob = dist.log_prob(action)
        state_val = self.critic(state)

        return action.detach(), action_logprob.detach(), state_val.detach()

    def evaluate(self, state, action):
        """Score a batch of (state, action) pairs under the current policy.

        Returns (action_logprobs, state_values, dist_entropy); gradients are
        kept for the PPO update.
        """
        if self.has_continuous_action_space:
            action_mean = self.actor(state)

            action_var = self.action_var.expand_as(action_mean)
            cov_mat = torch.diag_embed(action_var).to(device)
            dist = MultivariateNormal(action_mean, cov_mat)

            # Single-dimension continuous actions arrive flattened; restore
            # the (batch, 1) shape MultivariateNormal.log_prob expects.
            if self.action_dim == 1:
                action = action.reshape(-1, self.action_dim)
        else:
            x = nn.functional.relu(self.fc2(nn.functional.relu(self.fc1(state))))
            logits = self.actor(x)
            # Batched variant of the reshape in act(): (batch, n_branches, max_branch_size).
            logits_shaped = logits.view(-1,len(self.action_dim.nvec), self.action_dim.nvec.max())
            action_probs = nn.functional.softmax(logits_shaped, dim=-1)
            dist = Categorical(action_probs)

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_values = self.critic(state)

        return action_logprobs, state_values, dist_entropy
|
|
|
|
class PPO:
    """Clipped-surrogate PPO agent with separate actor/critic learning rates.

    Keeps a trainable policy (`policy`) and a frozen behaviour copy
    (`policy_old`) used for rollouts; transitions accumulate in `buffer`
    until `update()` is called.
    """

    def __init__(self, state_dim, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, has_continuous_action_space, action_std_init=0.6):
        self.has_continuous_action_space = has_continuous_action_space

        if has_continuous_action_space:
            self.action_std = action_std_init

        self.gamma = gamma        # discount factor for Monte-Carlo returns
        self.eps_clip = eps_clip  # PPO clipping range epsilon
        self.K_epochs = K_epochs  # optimisation epochs per update() call

        self.buffer = RolloutBuffer()

        self.policy = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        # Custom per-head weight initialisation (Kaiming actor, Xavier critic).
        self.policy.actor.apply(self.policy._initialize_actor)
        self.policy.critic.apply(self.policy._initialize_critic)
        # Distinct learning rates for the actor and critic parameter groups.
        self.optimizer = torch.optim.Adam([
            {'params': self.policy.actor.parameters(), 'lr': lr_actor},
            {'params': self.policy.critic.parameters(), 'lr': lr_critic}
        ])

        # Behaviour policy: frozen snapshot used to generate rollouts.
        self.policy_old = ActorCritic(state_dim, action_dim, has_continuous_action_space, action_std_init).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def set_action_std(self, new_action_std):
        """Set the exploration std on both policies (continuous spaces only)."""
        if self.has_continuous_action_space:
            self.action_std = new_action_std
            self.policy.set_action_std(new_action_std)
            self.policy_old.set_action_std(new_action_std)
        else:
            print("--------------------------------------------------------------------------------------------")
            print("WARNING : Calling PPO::set_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")

    def decay_action_std(self, action_std_decay_rate, min_action_std):
        """Linearly decay the exploration std, clamping it at `min_action_std`."""
        print("--------------------------------------------------------------------------------------------")
        if self.has_continuous_action_space:
            self.action_std = self.action_std - action_std_decay_rate
            self.action_std = round(self.action_std, 4)
            if (self.action_std <= min_action_std):
                self.action_std = min_action_std
                print("setting actor output action_std to min_action_std : ", self.action_std)
            else:
                print("setting actor output action_std to : ", self.action_std)
            self.set_action_std(self.action_std)
        else:
            print("WARNING : Calling PPO::decay_action_std() on discrete action space policy")
            print("--------------------------------------------------------------------------------------------")

    def select_action(self, state):
        """Sample an action from the behaviour policy and record the transition.

        Returns a numpy array (flattened for continuous action spaces).
        The reward and terminal flag for this step must be appended to the
        buffer by the caller.
        """
        # The two original branches were identical except for the return
        # value, so the shared work is factored out here.
        with torch.no_grad():
            state = torch.FloatTensor(state).to(device)
            action, action_logprob, state_val = self.policy_old.act(state)

        self.buffer.states.append(state)
        self.buffer.actions.append(action)
        self.buffer.logprobs.append(action_logprob)
        self.buffer.state_values.append(state_val)

        if self.has_continuous_action_space:
            return action.detach().cpu().numpy().flatten()
        return action.cpu().numpy()

    def update(self):
        """Run K_epochs of clipped-surrogate optimisation on the buffer contents."""
        # Monte-Carlo discounted returns, reset at each episode boundary.
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(self.buffer.rewards), reversed(self.buffer.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)

        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)

        old_states = torch.squeeze(torch.stack(self.buffer.states, dim=0)).detach().to(device)
        old_actions = torch.squeeze(torch.stack(self.buffer.actions, dim=0)).detach().to(device)
        old_logprobs = torch.squeeze(torch.stack(self.buffer.logprobs, dim=0)).detach().to(device)
        old_state_values = torch.squeeze(torch.stack(self.buffer.state_values, dim=0)).detach().to(device)

        # Advantages are computed once and held fixed across the K epochs.
        advantages = rewards.detach() - old_state_values.detach()

        for _ in range(self.K_epochs):
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            state_values = torch.squeeze(state_values)

            # Importance ratios between current and behaviour policy.
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # BUGFIX: only add a per-branch axis to advantages when ratios is
            # 2-D (multi-discrete case, shape (batch, n_branches)). For
            # continuous actions ratios is 1-D and the unconditional
            # unsqueeze(1) produced a spurious (batch, batch) outer product
            # that silently corrupted the surrogate loss.
            per_sample_adv = advantages.unsqueeze(1) if ratios.dim() > 1 else advantages
            surr1 = ratios * per_sample_adv
            surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * per_sample_adv

            # Clipped policy loss + 0.7 * value loss - small entropy bonus.
            loss = -torch.min(surr1, surr2) + 0.7 * self.MseLoss(state_values, rewards) - 0.0010 * dist_entropy

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # Sync the behaviour policy with the freshly optimised weights.
        self.policy_old.load_state_dict(self.policy.state_dict())

    def save(self, checkpoint_path):
        """Serialise the behaviour policy's weights to `checkpoint_path`."""
        torch.save(self.policy_old.state_dict(), checkpoint_path)

    def load(self, checkpoint_path):
        """Load weights into both policies from `checkpoint_path`, mapped to CPU."""
        self.policy_old.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
        self.policy.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))