| """ |
| Author: Minh Pham-Dinh |
| Created: Jan 26th, 2024 |
| Last Modified: Feb 10th, 2024 |
| Email: mhpham26@colby.edu |
| |
| Description: |
| File containing all models that will be used in Dreamer. |
| |
| The implementation is based on: |
| Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination," 2019. |
| [Online]. Available: https://arxiv.org/abs/1912.01603 |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
|
|
def initialize_weights(m):
    """Kaiming-initialise conv / linear weights and zero their biases.

    Intended to be passed to ``nn.Module.apply`` so it is called on every
    submodule; modules of any other type are left untouched.

    Args:
        m (nn.Module): module to initialise in place.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
    elif isinstance(m, nn.Linear):
        # torch default: leaky_relu gain with a=0 (numerically the ReLU gain).
        nn.init.kaiming_uniform_(m.weight)
    else:
        return
    # Layers constructed with bias=False expose ``m.bias`` as None.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
|
|
|
|
class RSSM(nn.Module):
    """Recurrent State Space Model (RSSM).

    The main model used to learn the latent dynamics of the environment. The
    latent state has a deterministic part ``h_t`` (the hidden state of a GRU
    cell) and a stochastic part ``s_t`` (a sample from a diagonal Gaussian).

    Args:
        stochastic_size: size of the stochastic state ``s_t``.
        obs_embed_size: size of the embedded (encoded) observation ``o_t``.
        deterministic_size: size of the deterministic state ``h_t``.
        hidden_size: width of the hidden layer in every MLP head.
        action_size: size of the action vector ``a_t``.
        activation: activation module class used in the MLP heads.
    """

    # Floor added to the softplus-ed standard deviation of both Gaussian heads
    # so the predicted variance never collapses to zero.
    MIN_STD = 0.1

    def __init__(self, stochastic_size, obs_embed_size, deterministic_size, hidden_size, action_size, activation=nn.ELU):
        super().__init__()
        self.stochastic_size = stochastic_size
        self.obs_embed_size = obs_embed_size
        self.deterministic_size = deterministic_size
        self.hidden_size = hidden_size
        self.action_size = action_size

        # Projects the concatenation [a_{t-1}, s_{t-1}] to the GRU input size.
        self.recurrent_linear = nn.Sequential(
            nn.Linear(stochastic_size + action_size, hidden_size),
            activation(),
        )
        self.gru_cell = nn.GRUCell(hidden_size, deterministic_size)

        # Posterior head: [o_t, h_t] -> (mean, raw std) of s_t.
        # NOTE: the attribute name (typo included) is kept on purpose; renaming
        # it would change the state_dict keys and break existing checkpoints.
        self.representatio_model = nn.Sequential(
            nn.Linear(deterministic_size + obs_embed_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, stochastic_size * 2),
        )

        # Prior head: h_t -> (mean, raw std) of s_t.
        self.transition_model = nn.Sequential(
            nn.Linear(deterministic_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, stochastic_size * 2),
        )

    def _gaussian(self, params):
        """Turn a (..., 2 * stochastic_size) head output into a Normal and a sample.

        The first half of the last dimension is the mean; the second half is
        mapped through softplus and offset by ``MIN_STD`` to give the std.

        Returns:
            (torch.distributions.Normal, torch.Tensor): the distribution and a
            reparameterised (differentiable) sample from it.
        """
        mean, raw_std = torch.chunk(params, 2, -1)
        std = F.softplus(raw_std) + self.MIN_STD
        dist = torch.distributions.Normal(mean, std)
        return dist, dist.rsample()

    def recurrent(self, stoch_state, action, deterministic):
        """Compute the next deterministic state h_t from (s_{t-1}, a_{t-1}, h_{t-1}).

        Args:
            stoch_state (batch_size, stochastic_size): stochastic state s_{t-1}.
            action (batch_size, action_size): action a_{t-1}.
            deterministic (batch_size, deterministic_size): deterministic state h_{t-1}.

        Returns:
            h_t (batch_size, deterministic_size): deterministic state at the next step.
        """
        x = torch.cat((action, stoch_state), -1)
        return self.gru_cell(self.recurrent_linear(x), deterministic)

    def representation(self, embed_obs, deterministic):
        """Posterior over the stochastic state, conditioned on the observation.

        This is the representation model p(s_t | h_t, o_t) in the Dreamer paper's notation.

        Args:
            embed_obs (batch_size, obs_embed_size): embedded observation o_t.
            deterministic (batch_size, deterministic_size): deterministic state h_t.

        Returns:
            post_dist: Normal distribution over s_t.
            post (batch_size, stochastic_size): reparameterised sample of s_t.
        """
        x = torch.cat((embed_obs, deterministic), -1)
        return self._gaussian(self.representatio_model(x))

    def transition(self, deterministic):
        """Prior over the stochastic state, without seeing the observation.

        This is the transition model q(s_t | h_t) in the Dreamer paper's notation.

        Args:
            deterministic (batch_size, deterministic_size): deterministic state h_t.

        Returns:
            prior_dist: Normal distribution over s_t.
            prior (batch_size, stochastic_size): reparameterised sample of s_t.
        """
        return self._gaussian(self.transition_model(deterministic))
| |
|
|
class ConvEncoder(nn.Module):
    """Convolutional observation encoder.

    Applies four unpadded convolutions (kernel 4, stride 2) with channel widths
    depth, 2*depth, 4*depth and 8*depth, each followed by ``activation``.
    Any leading batch dimensions of the input are flattened before the
    convolutions and restored afterwards; the conv output of each sample is
    flattened into a single feature vector.

    Args:
        depth: base channel multiplier.
        input_shape: shape of a single observation, (channels, height, width).
        activation: activation module class applied after every convolution.
    """
    def __init__(self, depth=32, input_shape=(3,64,64), activation=nn.ReLU):
        super().__init__()
        self.depth = depth
        self.input_shape = input_shape

        channels = [input_shape[0]] + [depth * mult for mult in (1, 2, 4, 8)]
        layers = []
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers.append(
                nn.Conv2d(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    kernel_size=4,
                    stride=2,
                    padding="valid",
                )
            )
            layers.append(activation())
        self.conv_layer = nn.Sequential(*layers)
        self.conv_layer.apply(initialize_weights)

    def forward(self, x):
        """Encode observations of shape (*batch, *input_shape) to (*batch, features).

        An un-batched observation (exactly ``input_shape``) is returned with a
        leading batch dimension of 1.
        """
        obs_rank = len(self.input_shape)
        lead_dims = x.shape[:-obs_rank] or (1, )
        features = self.conv_layer(x.reshape(-1, *self.input_shape))
        return features.reshape(*lead_dims, -1)
| |
|
|
class ConvDecoder(nn.Module):
    """Decode a latent state back to pixel space.

    Also referred to as the observation model in the official Dreamer paper.
    The latent vector is projected by a linear layer to ``depth * 32`` channels
    at 1x1 resolution, then upsampled by four stride-2 transposed convolutions
    (kernels 5, 5, 6, 6). By the transposed-conv size formula the spatial size
    goes 1 -> 5 -> 13 -> 30 -> 64 regardless of ``depth``.
    NOTE(review): the spatial size is therefore fixed at 64x64, so
    ``out_shape[1:]`` presumably has to be (64, 64); confirm before changing it.

    Args:
        stochastic_size: size of the stochastic state.
        deterministic_size: size of the deterministic state.
        depth: base channel multiplier.
        out_shape: shape of one reconstructed observation (channels, height, width).
        activation: activation module class applied between transposed convs.
    """
    def __init__(self, stochastic_size, deterministic_size, depth=32, out_shape=(3,64,64), activation=nn.ReLU):
        super().__init__()
        self.out_shape = out_shape

        # (in_channels, out_channels, kernel_size) for each stride-2 transposed conv.
        deconv_specs = [
            (depth * 32, depth * 4, 5),
            (depth * 4, depth * 2, 5),
            (depth * 2, depth * 1, 5 + 1),
            (depth * 1, out_shape[0], 5 + 1),
        ]
        layers = [
            nn.Linear(deterministic_size + stochastic_size, depth * 32),
            # (B, C) -> (B, C, 1) -> (B, C, 1, 1)
            nn.Unflatten(1, (depth * 32, 1)),
            nn.Unflatten(2, (1, 1)),
        ]
        last = len(deconv_specs) - 1
        for idx, (in_ch, out_ch, ksize) in enumerate(deconv_specs):
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=ksize, stride=2))
            # No activation after the final layer: it outputs the pixel means.
            if idx != last:
                layers.append(activation())
        self.net = nn.Sequential(*layers)
        self.net.apply(initialize_weights)

    def forward(self, posterior, deterministic, mps_flatten=False):
        """Reconstruct observations from the latent state.

        Args:
            posterior (*batch, stoch_size): stochastic state s_t.
            deterministic (*batch, deterministic_size): deterministic state h_t.
            mps_flatten (bool): if True, collapse all batch dimensions into one.
                Apple's MPS backend supports at most 4-D tensors in some ops, so
                this keeps the output at (N, *out_shape).

        Returns:
            Independent(Normal(mean, 1)) over pixels, with the last
            ``len(out_shape)`` dimensions treated as the event.
        """
        latent = torch.cat((posterior, deterministic), -1)
        if mps_flatten:
            lead_dims = (-1, )
        else:
            lead_dims = latent.shape[:-1] or (1, )

        pixel_mean = self.net(latent.reshape(-1, latent.shape[-1]))
        pixel_mean = pixel_mean.reshape(*lead_dims, *self.out_shape)

        return torch.distributions.Independent(
            torch.distributions.Normal(pixel_mean, 1), len(self.out_shape)
        )
| |
| |
class RewardNet(nn.Module):
    """Reward prediction model.

    Concatenates the stochastic and deterministic states into the latent
    state and regresses a scalar reward with a one-hidden-layer MLP.

    Args:
        input_size: size of the concatenated latent state.
        hidden_size: width of the hidden layer.
        activation: activation module class for the hidden layer.
    """
    def __init__(self, input_size, hidden_size, activation=nn.ELU):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, stoch_state, deterministic):
        """Predict rewards for a (possibly batched) latent state.

        Args:
            stoch_state (*batch, stoch_size): stochastic state s_t.
            deterministic (*batch, deterministic_size): deterministic state h_t.

        Returns:
            Tensor of shape (*batch, 1) with the predicted rewards; an
            un-batched input yields shape (1, 1).
        """
        latent = torch.cat((stoch_state, deterministic), dim=-1)
        lead_dims = tuple(latent.shape[:-1]) or (1, )
        flat_latent = latent.reshape(-1, latent.shape[-1])
        return self.net(flat_latent).reshape(*lead_dims, 1)
| |
|
|
class ContinuoNet(nn.Module):
    """Continuation / termination prediction model.

    Concatenates the stochastic and deterministic states into the latent
    state and predicts a single Bernoulli logit per sample with a
    two-hidden-layer MLP.
    NOTE(review): whether a positive outcome means "episode continues" or
    "episode terminated" is fixed by the training targets, not here; confirm
    against the caller.

    Args:
        input_size: size of the concatenated latent state.
        hidden_size: width of the hidden layers.
        activation: activation module class for the hidden layers.
    """
    def __init__(self, input_size, hidden_size, activation=nn.ELU):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, stoch_state, deterministic):
        """Predict the Bernoulli logit and distribution for a latent state.

        Args:
            stoch_state (*batch, stoch_size): stochastic state s_t.
            deterministic (*batch, deterministic_size): deterministic state h_t.

        Returns:
            logits (*batch, 1): raw Bernoulli logits (an un-batched input
                yields shape (1, 1)).
            dist: Independent(Bernoulli(logits=logits), 1), so the trailing
                dimension of size 1 is the event dimension.
        """
        latent = torch.cat((stoch_state, deterministic), dim=-1)
        lead_dims = tuple(latent.shape[:-1]) or (1, )
        flat_latent = latent.reshape(-1, latent.shape[-1])
        logits = self.net(flat_latent).reshape(*lead_dims, 1)
        bernoulli = torch.distributions.Bernoulli(logits=logits)
        return logits, torch.distributions.Independent(bernoulli, 1)
| |
| |
class Actor(nn.Module):
    """Actor (policy) network.

    Maps the latent state (stochastic state concatenated with deterministic
    state) through a one-hidden-layer MLP and samples an action from it.

    Args:
        latent_size: size of the concatenated latent state.
        hidden_size: width of the hidden layer.
        action_size: dimensionality of the action space. For continuous
            actions the stored ``self.action_size`` is doubled, because the
            network outputs a mean and a std per action dimension.
        discrete: if True, sample one-hot actions from a categorical
            distribution; otherwise sample from a tanh-squashed Gaussian.
        activation: activation module class for the hidden layer.
        min_std: constant added to the Gaussian standard deviation.
        init_std: standard deviation produced when the std pre-activation is 0.
        mean_scale: bound on |mean|; the mean is squashed as
            ``mean_scale * tanh(mean / mean_scale)``.
    """
    def __init__(self,
                 latent_size,
                 hidden_size,
                 action_size,
                 discrete=True,
                 activation=nn.ELU,
                 min_std=1e-4,
                 init_std=5,
                 mean_scale=5):

        super().__init__()
        self.latent_size = latent_size
        self.hidden_size = hidden_size
        # Continuous actions need (mean, std) per dimension, hence the x2.
        self.action_size = (action_size if discrete else action_size * 2)
        self.discrete = discrete
        self.min_std = min_std
        self.init_std = init_std
        self.mean_scale = mean_scale

        self.net = nn.Sequential(
            nn.Linear(latent_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, self.action_size)
        )

    def forward(self, stoch_state, deterministic):
        """Sample an action for the given latent state.

        Args:
            stoch_state (*batch, stoch_size): stochastic state s_t.
            deterministic (*batch, deterministic_size): deterministic state h_t.

        Returns:
            A sampled action tensor (not a distribution):
            - discrete: a one-hot sample of shape (*batch, action_size) with
              straight-through gradients (forward value is the one-hot sample,
              gradients flow through the categorical probabilities);
            - continuous: a reparameterised sample of tanh(Normal(mean, std))
              of shape (*batch, action_size // 2), with entries in [-1, 1].
        """
        latent_state = torch.cat((stoch_state, deterministic), -1)
        x = self.net(latent_state)

        if self.discrete:
            dist = torch.distributions.OneHotCategorical(logits=x)
            # Straight-through estimator: adds zero in value, but carries the
            # gradient of the probabilities through the non-differentiable sample.
            return dist.sample() + dist.probs - dist.probs.detach()

        # Inverse softplus of init_std, so that softplus(0 + raw_init_std) == init_std.
        raw_init_std = np.log(np.exp(self.init_std) - 1)

        mean, std = torch.chunk(x, 2, -1)
        mean = self.mean_scale * torch.tanh(mean / self.mean_scale)
        std = F.softplus(std + raw_init_std) + self.min_std

        dist = torch.distributions.Normal(mean, std)
        dist = torch.distributions.TransformedDistribution(dist, torch.distributions.TanhTransform())
        return torch.distributions.Independent(dist, 1).rsample()
| |
| |
class Critic(nn.Module):
    """Critic (state-value) network.

    Concatenates the stochastic and deterministic states into the latent
    state and predicts a scalar value with a two-hidden-layer MLP, returned
    as a unit-variance Normal distribution.

    Args:
        latent_size: size of the concatenated latent state.
        hidden_size: width of the hidden layers.
        activation: activation module class for the hidden layers.
    """
    def __init__(self, latent_size, hidden_size, activation=nn.ELU):
        super().__init__()
        self.latent_size = latent_size

        self.net = nn.Sequential(
            nn.Linear(latent_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, stoch_state, deterministic):
        """Predict the state-value distribution for a latent state.

        Args:
            stoch_state (*batch, stoch_size): stochastic state s_t.
            deterministic (*batch, deterministic_size): deterministic state h_t.

        Returns:
            Independent(Normal(value, 1), 1) whose mean has shape (*batch, 1);
            an un-batched input yields a mean of shape (1, 1).
        """
        latent = torch.cat((stoch_state, deterministic), dim=-1)
        lead_dims = tuple(latent.shape[:-1]) or (1, )
        flat_latent = latent.reshape(-1, self.latent_size)
        value = self.net(flat_latent).reshape(*lead_dims, 1)
        return torch.distributions.Independent(torch.distributions.Normal(value, 1), 1)
| |