File size: 1,922 Bytes
"""Implementation of additional projectors for additional inputs to the VLA models."""
import torch
import torch.nn as nn
class ProprioProjector(nn.Module):
    """
    Projects proprio state inputs into the LLM's embedding space.

    The projection is a two-layer MLP applied to the last axis of the input:
    ``Linear(proprio_dim -> llm_dim)`` -> GELU -> ``Linear(llm_dim -> llm_dim)``.

    Args:
        llm_dim: Size of the last axis of the projected output.
        proprio_dim: Size of the last axis of the proprio input.
    """

    def __init__(self, llm_dim: int, proprio_dim: int) -> None:
        super().__init__()
        self.llm_dim = llm_dim
        self.proprio_dim = proprio_dim

        # NOTE: submodule attribute names determine the state_dict keys, so
        # `fc1` / `fc2` / `act_fn1` are kept exactly as they were.
        self.fc1 = nn.Linear(self.proprio_dim, self.llm_dim, bias=True)
        self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
        self.act_fn1 = nn.GELU()

    def forward(self, proprio: torch.Tensor = None) -> torch.Tensor:
        """
        Project a batch of proprio states.

        Args:
            proprio: Tensor of shape (bsz, proprio_dim). The ``None`` default is
                kept only for signature compatibility; a tensor is required.

        Returns:
            Tensor of shape (bsz, llm_dim).

        Raises:
            TypeError: If ``proprio`` is None.
        """
        # Fail early with a message that names the argument, instead of the
        # opaque TypeError raised from inside the first linear layer.
        if proprio is None:
            raise TypeError("ProprioProjector.forward() requires a `proprio` tensor, but got None")

        # proprio: (bsz, proprio_dim)
        hidden = self.act_fn1(self.fc1(proprio))
        return self.fc2(hidden)
class NoisyActionProjector(nn.Module):
    """
    [Diffusion] Projects noisy action inputs into the LLM's embedding space.

    Note that since each action is tokenized into 7 tokens in OpenVLA (rather
    than having 1 token per action), each noisy action token will have dimension 1
    instead of 7.

    The projection is a two-layer MLP applied to the last axis of the input:
    ``Linear(1 -> llm_dim)`` -> GELU -> ``Linear(llm_dim -> llm_dim)``.

    Args:
        llm_dim: Size of the last axis of the projected output.
    """

    def __init__(self, llm_dim: int) -> None:
        super().__init__()
        self.llm_dim = llm_dim
        # One scalar per noisy action token (see the class docstring).
        self.action_token_dim = 1

        # NOTE: submodule attribute names determine the state_dict keys, so
        # `fc1` / `fc2` / `act_fn1` are kept exactly as they were.
        self.fc1 = nn.Linear(self.action_token_dim, self.llm_dim, bias=True)
        self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
        self.act_fn1 = nn.GELU()

    def forward(self, noisy_actions: torch.Tensor = None) -> torch.Tensor:
        """
        Project a batch of noisy action tokens.

        Args:
            noisy_actions: Tensor of shape
                (bsz, num_action_tokens=chunk_len*action_dim, 1). The ``None``
                default is kept only for signature compatibility; a tensor is
                required.

        Returns:
            Tensor of shape (bsz, num_action_tokens, llm_dim).

        Raises:
            TypeError: If ``noisy_actions`` is None.
        """
        # Fail early with a message that names the argument, instead of the
        # opaque TypeError raised from inside the first linear layer.
        if noisy_actions is None:
            raise TypeError(
                "NoisyActionProjector.forward() requires a `noisy_actions` tensor, but got None"
            )

        # noisy_actions: (bsz, num_action_tokens=chunk_len*action_dim, 1)
        hidden = self.act_fn1(self.fc1(noisy_actions))
        return self.fc2(hidden)
|