Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class APM(nn.Module): | |
| """Action Proposal Module: predicts the action bridging two latent states.""" | |
| def __init__(self, z_dim: int = 4, action_dim: int = 4, hidden_dim: int = 256): | |
| super().__init__() | |
| self.net = nn.Sequential( | |
| nn.Linear(2 * z_dim, hidden_dim), | |
| nn.ReLU(inplace=True), | |
| nn.Linear(hidden_dim, hidden_dim), | |
| nn.ReLU(inplace=True), | |
| nn.Linear(hidden_dim, action_dim), | |
| ) | |
| self.z_dim = z_dim | |
| self.action_dim = action_dim | |
| def forward(self, z_i: torch.Tensor, z_j: torch.Tensor) -> torch.Tensor: | |
| return self.net(torch.cat([z_i, z_j], dim=-1)) | |
| def apm_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
| return F.mse_loss(pred, target) | |