Spaces:
Sleeping
Sleeping
File size: 846 Bytes
70e204b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 | import torch
import torch.nn as nn
import torch.nn.functional as F
class APM(nn.Module):
    """Action Proposal Module.

    A three-layer MLP (Linear -> ReLU -> Linear -> ReLU -> Linear) that maps a
    pair of latent vectors to an action vector. The two inputs are concatenated
    along their last dimension, so their combined width must equal
    ``2 * z_dim`` (presumably ``z_dim`` each — confirm against callers).

    Args:
        z_dim: Width of a single latent vector; the first layer expects
            ``2 * z_dim`` input features.
        action_dim: Number of output features (the predicted action).
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, z_dim: int = 4, action_dim: int = 4, hidden_dim: int = 256):
        super().__init__()
        self.z_dim = z_dim
        self.action_dim = action_dim
        # Kept as a flat list so the Linear layers land at Sequential indices
        # 0, 2 and 4, giving state_dict keys net.0.*, net.2.*, net.4.*.
        layers = [
            nn.Linear(z_dim * 2, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z_i: torch.Tensor, z_j: torch.Tensor) -> torch.Tensor:
        """Predict the action bridging latent ``z_i`` to latent ``z_j``.

        Args:
            z_i: First latent tensor; leading dims must match ``z_j``.
            z_j: Second latent tensor.

        Returns:
            Tensor whose last dimension is ``action_dim``; leading dims are
            those shared by ``z_i`` and ``z_j``.
        """
        pair = torch.cat((z_i, z_j), dim=-1)
        return self.net(pair)
def apm_loss(
    pred: torch.Tensor, target: torch.Tensor, *, reduction: str = "mean"
) -> torch.Tensor:
    """Return the mean-squared error between predicted and target actions.

    Args:
        pred: Predicted actions (presumably ``APM`` outputs).
        target: Ground-truth actions. Should have the same shape as ``pred``:
            ``F.mse_loss`` broadcasts mismatched shapes and only emits a
            warning, which usually hides a bug.
        reduction: Passed straight to ``F.mse_loss``. ``"mean"`` (default,
            the original behavior) averages over all elements, ``"sum"`` adds
            them, and ``"none"`` returns the element-wise squared errors
            (useful for per-sample weighting or masking).

    Returns:
        A scalar tensor for ``"mean"``/``"sum"``, or an element-wise tensor
        for ``"none"``.
    """
    return F.mse_loss(pred, target, reduction=reduction)
|