# lsr-lang / models/apm.py
# Uploaded by singhanshuman with huggingface_hub
# ("Upload models/apm.py with huggingface_hub", commit 70e204b, verified).
# Original file size: 846 bytes.
import torch
import torch.nn as nn
import torch.nn.functional as F
class APM(nn.Module):
    """Action Proposal Module: predicts the action bridging two latent states.

    A three-layer MLP (two ReLU hidden layers of equal width) maps the
    concatenation of two latent vectors ``z_i`` and ``z_j`` to an action
    vector.

    Args:
        z_dim: Size of each latent vector; the MLP input width is ``2 * z_dim``.
        action_dim: Size of the predicted action vector (MLP output width).
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, z_dim: int = 4, action_dim: int = 4, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * z_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, action_dim),
        )
        self.z_dim = z_dim
        self.action_dim = action_dim

    def forward(self, z_i: torch.Tensor, z_j: torch.Tensor) -> torch.Tensor:
        """Predict the action bridging ``z_i`` and ``z_j``.

        The ordering is presumably "from ``z_i`` to ``z_j``"; confirm against
        the training data.

        Args:
            z_i: First latent(s), expected shape ``(..., z_dim)``.
            z_j: Second latent(s), expected shape ``(..., z_dim)``. The
                leading (batch) dimensions of ``z_i`` and ``z_j`` are
                broadcast against each other. For example, a single latent of
                shape ``(z_dim,)`` can be paired with a batch of shape
                ``(B, z_dim)``. The last dimensions of the two inputs must sum
                to ``2 * z_dim`` to match the first linear layer.

        Returns:
            Predicted actions of shape ``(*broadcast_leading_shape, action_dim)``.
        """
        # torch.cat does not broadcast, so expand the leading dims explicitly.
        # Each input keeps its own feature (last) dim. expand() returns a view,
        # so nothing is copied when the leading shapes already agree.
        lead = torch.broadcast_shapes(z_i.shape[:-1], z_j.shape[:-1])
        z_i = z_i.expand(*lead, z_i.shape[-1])
        z_j = z_j.expand(*lead, z_j.shape[-1])
        return self.net(torch.cat([z_i, z_j], dim=-1))
def apm_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean-squared-error loss between predicted and target tensors.

    Args:
        pred: Predicted values (e.g. actions produced by an ``APM`` forward pass).
        target: Reference values; should have the same shape as ``pred``.
            ``F.mse_loss`` only warns and broadcasts on a mismatch.

    Returns:
        A 0-dim tensor: the mean of the element-wise squared differences.
    """
    mean_squared_error = F.mse_loss(input=pred, target=target, reduction="mean")
    return mean_squared_error