haofuly's picture
Add files using upload-large-folder tool
cf587f4 verified
raw
history blame
1.92 kB
"""Implementation of additional projectors for additional inputs to the VLA models."""
import torch
import torch.nn as nn
class ProprioProjector(nn.Module):
    """
    Projects proprio state inputs into the LLM's embedding space.

    The projection is a two-layer MLP: ``fc1`` (proprio_dim -> llm_dim),
    followed by a GELU activation, followed by ``fc2`` (llm_dim -> llm_dim).
    """

    def __init__(self, llm_dim: int, proprio_dim: int) -> None:
        """
        Build the projector layers.

        Args:
            llm_dim: Output feature size of both linear layers.
            proprio_dim: Feature size (last dimension) of the proprio input.
        """
        super().__init__()
        self.llm_dim = llm_dim
        self.proprio_dim = proprio_dim

        # Attribute names and creation order are kept unchanged: the names
        # determine the state_dict keys and the order determines seeded init.
        self.fc1 = nn.Linear(self.proprio_dim, self.llm_dim, bias=True)
        self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
        self.act_fn1 = nn.GELU()

    def forward(self, proprio: torch.Tensor = None) -> torch.Tensor:
        """
        Project a proprio state tensor through fc1 -> GELU -> fc2.

        Args:
            proprio: Tensor of shape (bsz, proprio_dim). The ``None`` default
                is retained only for signature compatibility; a tensor is
                required.

        Returns:
            Tensor whose last dimension is ``llm_dim``; leading dimensions
            match those of ``proprio``.

        Raises:
            TypeError: If ``proprio`` is None.
        """
        # Fail early with a readable message instead of letting None reach
        # the linear layer, which raises a far less helpful TypeError.
        if proprio is None:
            raise TypeError(
                "ProprioProjector.forward() requires a `proprio` tensor, got None"
            )

        # proprio: (bsz, proprio_dim)
        projected_features = self.fc1(proprio)
        projected_features = self.act_fn1(projected_features)
        projected_features = self.fc2(projected_features)
        return projected_features
class NoisyActionProjector(nn.Module):
    """
    [Diffusion] Projects noisy action inputs into the LLM's embedding space.

    Note that since each action is tokenized into 7 tokens in OpenVLA (rather
    than having 1 token per action), each noisy action token will have dimension 1
    instead of 7.

    The projection is a two-layer MLP: ``fc1`` (action_token_dim -> llm_dim),
    followed by a GELU activation, followed by ``fc2`` (llm_dim -> llm_dim).
    """

    def __init__(self, llm_dim: int, action_token_dim: int = 1) -> None:
        """
        Build the projector layers.

        Args:
            llm_dim: Output feature size of both linear layers.
            action_token_dim: Feature size (last dimension) of each noisy
                action token. Defaults to 1, the value previously hard-coded,
                so existing single-argument construction is unchanged.
        """
        super().__init__()
        self.llm_dim = llm_dim
        self.action_token_dim = action_token_dim

        # Attribute names and creation order are kept unchanged: the names
        # determine the state_dict keys and the order determines seeded init.
        self.fc1 = nn.Linear(self.action_token_dim, self.llm_dim, bias=True)
        self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
        self.act_fn1 = nn.GELU()

    def forward(self, noisy_actions: torch.Tensor = None) -> torch.Tensor:
        """
        Project noisy action tokens through fc1 -> GELU -> fc2.

        Args:
            noisy_actions: Tensor of shape
                (bsz, num_action_tokens=chunk_len*action_dim, action_token_dim).
                The ``None`` default is retained only for signature
                compatibility; a tensor is required.

        Returns:
            Tensor whose last dimension is ``llm_dim``; leading dimensions
            match those of ``noisy_actions``.

        Raises:
            TypeError: If ``noisy_actions`` is None.
        """
        # Fail early with a readable message instead of letting None reach
        # the linear layer, which raises a far less helpful TypeError.
        if noisy_actions is None:
            raise TypeError(
                "NoisyActionProjector.forward() requires a `noisy_actions` tensor, got None"
            )

        # noisy_actions: (bsz, num_action_tokens=chunk_len*action_dim, 1)
        projected_features = self.fc1(noisy_actions)
        projected_features = self.act_fn1(projected_features)
        projected_features = self.fc2(projected_features)
        return projected_features