singhanshuman commited on
Commit
70e204b
·
verified ·
1 Parent(s): 1624b35

Upload models/apm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/apm.py +26 -0
models/apm.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class APM(nn.Module):
7
+ """Action Proposal Module: predicts the action bridging two latent states."""
8
+
9
+ def __init__(self, z_dim: int = 4, action_dim: int = 4, hidden_dim: int = 256):
10
+ super().__init__()
11
+ self.net = nn.Sequential(
12
+ nn.Linear(2 * z_dim, hidden_dim),
13
+ nn.ReLU(inplace=True),
14
+ nn.Linear(hidden_dim, hidden_dim),
15
+ nn.ReLU(inplace=True),
16
+ nn.Linear(hidden_dim, action_dim),
17
+ )
18
+ self.z_dim = z_dim
19
+ self.action_dim = action_dim
20
+
21
+ def forward(self, z_i: torch.Tensor, z_j: torch.Tensor) -> torch.Tensor:
22
+ return self.net(torch.cat([z_i, z_j], dim=-1))
23
+
24
+
25
+ def apm_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
26
+ return F.mse_loss(pred, target)